diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 55fe5f1441..956bbb63d8 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -79,8 +79,7 @@ jobs:
         python -m pytest -v tests/python/contrib/test_rpc_server_device.py
 
   Windows:
-    if: ${{ github.repository == 'apache/tvm' }}
-    runs-on: windows-2019
+    runs-on: windows-2016
     steps:
       - uses: actions/checkout@v2
         with:
@@ -164,4 +163,4 @@ jobs:
         uses: actions/upload-artifact@v2
         with:
           name: android_camera-debug.apk
-          path: ./apps/android_camera/app/build/outputs/apk/debug/app-debug.apk
\ No newline at end of file
+          path: ./apps/android_camera/app/build/outputs/apk/debug/app-debug.apk
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index c2ee13a0fe..a3bcc6981d 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit c2ee13a0fe99241b0e798ce647acf98e237f1d0c
+Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3a21d22f78..7d483c9828 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -285,6 +285,14 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
   src/printer/*.cc
   src/support/*.cc
   src/script/*.cc
+  src/relax/ir/*.cc
+  src/relax/op/*.cc
+  src/relax/analysis/*.cc
+  src/relax/usmp/*.cc
+  src/relax/transform/*.cc
+  src/relax/backend/vm/*.cc
+  src/relax/backend/task_extraction.cc
+  src/relax/utils.cc
   )
 
 tvm_file_glob(GLOB CODEGEN_SRCS
@@ -315,6 +323,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS
 tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS
   src/relay/qnn/*.cc
   )
+
 list(APPEND COMPILER_SRCS ${RELAY_OP_SRCS})
 list(APPEND COMPILER_SRCS ${RELAY_PASS_SRCS})
 list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS})
@@ -329,6 +338,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
   src/runtime/*.cc
   src/runtime/vm/*.cc
   src/runtime/minrpc/*.cc
+  src/runtime/relax_vm/*.cc
   )
 
 if(BUILD_FOR_HEXAGON)
diff --git a/Jenkinsfile b/Jenkinsfile
index 78addc9b2c..9d6315f100 100755
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -45,20 +45,19 @@
 // 'python3 jenkins/generate.py'
 // Note: This timestamp is here to ensure that updates to the Jenkinsfile are
 // always rebased on main before merging:
-// Generated at 2022-10-04T13:17:33.929159
+// Generated at 2022-04-29T08:49:28.997200
 
 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils
+
 // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. -->
 ci_lint = 'tlcpack/ci-lint:20220925-060158-71f25b3d6'
 ci_gpu = 'tlcpack/ci-gpu:20220925-060158-71f25b3d6'
 ci_cpu = 'tlcpack/ci-cpu:20220925-060158-71f25b3d6'
-ci_minimal = 'tlcpack/ci-minimal:20220925-060158-71f25b3d6'
-ci_wasm = 'tlcpack/ci-wasm:20220925-060158-71f25b3d6'
-ci_i386 = 'tlcpack/ci-i386:20220925-060158-71f25b3d6'
-ci_cortexm = 'tlcpack/ci-cortexm:20220925-060158-71f25b3d6'
-ci_arm = 'tlcpack/ci-arm:20220925-060158-71f25b3d6'
+ci_wasm = 'tlcpack/ci-wasm:v0.72'
+ci_i386 = 'tlcpack/ci-i386:v0.75'
+ci_qemu = 'tlcpack/ci-qemu:v0.11'
+ci_arm = 'tlcpack/ci-arm:v0.08'
 ci_hexagon = 'tlcpack/ci-hexagon:20220925-060158-71f25b3d6'
-ci_riscv = 'tlcpack/ci-riscv:20220925-060158-71f25b3d6'
 // <--- End of regex-scanned config.
 
 // Parameters to allow overriding (in Jenkins UI), the images
@@ -66,50 +65,33 @@ ci_riscv = 'tlcpack/ci-riscv:20220925-060158-71f25b3d6'
 // over default values above.
 properties([
   parameters([
-    string(name: 'ci_arm_param', defaultValue: ''),
-    string(name: 'ci_cortexm_param', defaultValue: ''),
-    string(name: 'ci_cpu_param', defaultValue: ''),
-    string(name: 'ci_gpu_param', defaultValue: ''),
-    string(name: 'ci_hexagon_param', defaultValue: ''),
-    string(name: 'ci_i386_param', defaultValue: ''),
     string(name: 'ci_lint_param', defaultValue: ''),
-    string(name: 'ci_minimal_param', defaultValue: ''),
-    string(name: 'ci_riscv_param', defaultValue: ''),
+    string(name: 'ci_cpu_param', defaultValue: ''),
+    string(name: 'ci_gpu_param', defaultValue: ''),
     string(name: 'ci_wasm_param', defaultValue: ''),
+    string(name: 'ci_i386_param', defaultValue: ''),
+    string(name: 'ci_qemu_param', defaultValue: ''),
+    string(name: 'ci_arm_param', defaultValue: ''),
+    string(name: 'ci_hexagon_param', defaultValue: '')
   ])
 ])
 
-// Placeholders for newly built Docker image names (if rebuild_docker_images
-// is used)
-built_ci_arm = null;
-built_ci_cortexm = null;
-built_ci_cpu = null;
-built_ci_gpu = null;
-built_ci_hexagon = null;
-built_ci_i386 = null;
-built_ci_lint = null;
-built_ci_minimal = null;
-built_ci_riscv = null;
-built_ci_wasm = null;
+// tvm libraries
+tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake'
+tvm_lib = 'build/libtvm.so, ' + tvm_runtime
+// LLVM upstream lib
+tvm_multilib = 'build/libtvm.so, ' +
+               'build/libvta_fsim.so, ' +
+               tvm_runtime
 
-// Global variable assigned during Sanity Check that holds the sha1 which should be
-// merged into the PR in all branches.
-upstream_revision = null
+tvm_multilib_tsim = 'build/libvta_tsim.so, ' +
+                    tvm_multilib
 
 // command to start a docker container
-docker_run = 'docker/bash.sh --env CI --env TVM_SHARD_INDEX --env TVM_NUM_SHARDS --env RUN_DISPLAY_URL --env PLATFORM --env SKIP_SLOW_TESTS --env TEST_STEP_NAME'
-docker_build = 'docker/build.sh'
+docker_run = 'docker/bash.sh'
 // timeout in minutes
-max_time = 180
-rebuild_docker_images = false
-
-// Filenames for stashing between build and test steps
-s3_prefix = "tvm-jenkins-artifacts-prod/tvm/${env.BRANCH_NAME}/${env.BUILD_NUMBER}"
+max_time = 240
 
-
-// General note: Jenkins has limits on the size of a method (or top level code)
-// that are pretty strict, so most usage of groovy methods in these templates
-// are purely to satisfy the JVM
 def per_exec_ws(folder) {
   return "workspace/exec_${env.EXECUTOR_NUMBER}/" + folder
 }
@@ -117,71 +99,15 @@ def per_exec_ws(folder) {
 // initialize source codes
 def init_git() {
   checkout scm
-
-  // Add more info about job node
   sh (
     script: './tests/scripts/task_show_node_info.sh',
     label: 'Show executor node info',
   )
-
-  // Determine merge commit to use for all stages
-  sh (
-    script: 'git fetch origin main',
-    label: 'Fetch upstream',
-  )
-  if (upstream_revision == null) {
-    upstream_revision = sh(
-      script: 'git log -1 FETCH_HEAD --format=\'%H\'',
-      label: 'Determine upstream revision',
-      returnStdout: true,
-    ).trim()
-  }
-  sh (
-    script: "git -c user.name=TVM-Jenkins -c user.email=jenkins@tvm.apache.org merge ${upstream_revision}",
-    label: 'Merge to origin/main'
-  )
-
-  sh(
-    script: """
-      set -eux
-      . ci/scripts/retry.sh
-      retry 3 timeout 5m git submodule update --init -f --jobs 0
-    """,
-    label: 'Update git submodules',
-  )
-  checkout_trusted_files()
-}
-
-def docker_init(image) {
-  // Clear out all Docker images that aren't going to be used
-  sh(
-    script: """
-      set -eux
-      docker image ls --all
-      IMAGES=\$(docker image ls --all --format '{{.Repository}}:{{.Tag}} {{.ID}}')
-
-      echo -e "Found images:\\n\$IMAGES"
-      echo "\$IMAGES" | { grep -vE '${image}' || test \$? = 1; } | { xargs docker rmi || test \$? = 123; }
-
-      docker image ls --all
-    """,
-    label: 'Clean old Docker images',
-  )
-
-  if (image.contains("amazonaws.com")) {
-    // If this string is in the image name it's from ECR and needs to be pulled
-    // with the right credentials
-    ecr_pull(image)
-  } else {
-    sh(
-      script: """
-        set -eux
-        . ci/scripts/retry.sh
-        retry 5 docker pull ${image}
-      """,
-      label: 'Pull docker image',
-    )
+  retry(5) {
+    timeout(time: 2, unit: 'MINUTES') {
+      sh (script: 'git submodule update --init -f', label: 'Update git submodules')
+    }
   }
 }
@@ -193,7 +119,7 @@ def should_skip_slow_tests(pr_number) {
     // Exit code of 1 means run slow tests, exit code of 0 means skip slow tests
     result = sh (
       returnStatus: true,
-      script: "./ci/scripts/should_run_slow_tests.py --pr '${pr_number}'",
+      script: "./tests/scripts/should_run_slow_tests.py --pr '${pr_number}'",
       label: 'Check if CI should run slow tests',
     )
   }
@@ -211,448 +137,151 @@ def cancel_previous_build() {
   }
 }
 
-def checkout_trusted_files() {
-  // trust everything from branch builds
-  if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) {
-    return;
-  }
-
-  // trust peoople listed in CONTRIBUTING.md
-  grep_code = sh(
-    returnStatus: true,
-    script: "git show '${upstream_revision}:CONTRIBUTORS.md' | grep '@${env.CHANGE_AUTHOR}'",
-    label: 'Check if change is from a contributor',
-  )
-
-  if (grep_code == 1) {
-    // Any scripts that run on the bare host and not inside a Docker container
-    // (especially those that access secrets) should be checked out here so
-    // only trusted versions are used in CI
-    sh(
-      script: "git checkout ${upstream_revision} ci/scripts/.",
-      label: 'Check out trusted files',
-    )
-  }
-}
-
 def should_skip_ci(pr_number) {
-  if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) {
-    // never skip CI on build sourced from a branch
-    return false
-  }
-  glob_skip_ci_code = sh (
-    returnStatus: true,
-    script: "./ci/scripts/git_skip_ci_globs.py",
-    label: 'Check if CI should be skipped due to changed files',
-  )
-  if (glob_skip_ci_code == 0) {
-    return true
-  }
   withCredentials([string(
     credentialsId: 'tvm-bot-jenkins-reader',
-    variable: 'GITHUB_TOKEN',
+    variable: 'TOKEN',
   )]) {
     // Exit code of 1 means run full CI (or the script had an error, so run
     // full CI just in case). Exit code of 0 means skip CI.
     git_skip_ci_code = sh (
       returnStatus: true,
-      script: "./ci/scripts/git_skip_ci.py --pr '${pr_number}'",
+      script: "./tests/scripts/git_skip_ci.py --pr '${pr_number}'",
      label: 'Check if CI should be skipped',
     )
-  }
+    }
   return git_skip_ci_code == 0
 }
 
-def check_pr(pr_number) {
-  if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) {
-    // never skip CI on build sourced from a branch
-    return false
-  }
-  withCredentials([string(
-    credentialsId: 'tvm-bot-jenkins-reader',
-    variable: 'GITHUB_TOKEN',
-  )]) {
-    sh (
-      script: "python3 ci/scripts/check_pr.py --pr ${pr_number}",
-      label: 'Check PR title and body',
-    )
-  }
-
-}
+cancel_previous_build()
 
-def prepare() {
-  stage('Prepare') {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") {
+def lint() {
+stage('Lint') {
+  node('CPU') {
+    // When something is provided in ci_*_param, use it, otherwise default with ci_*
+    ci_lint = params.ci_lint_param ?: ci_lint
+    ci_cpu = params.ci_cpu_param ?: ci_cpu
+    ci_gpu = params.ci_gpu_param ?: ci_gpu
+    ci_wasm = params.ci_wasm_param ?: ci_wasm
+    ci_i386 = params.ci_i386_param ?: ci_i386
+    ci_qemu = params.ci_qemu_param ?: ci_qemu
+    ci_arm = params.ci_arm_param ?: ci_arm
+    ci_hexagon = params.ci_hexagon_param ?: ci_hexagon
+
+    sh (script: """
+      echo "Docker images being used in this build:"
+      echo " ci_lint = ${ci_lint}"
+      echo " ci_cpu = ${ci_cpu}"
+      echo " ci_gpu = ${ci_gpu}"
+      echo " ci_wasm = ${ci_wasm}"
+      echo " ci_i386 = ${ci_i386}"
+      echo " ci_qemu = ${ci_qemu}"
+      echo " ci_arm = ${ci_arm}"
+      echo " ci_hexagon = ${ci_hexagon}"
+    """, label: 'Docker image names')
+  }
+}
+
+stage('Sanity Check') {
+  timeout(time: max_time, unit: 'MINUTES') {
+    node('CPU') {
+      ws(per_exec_ws('tvm/sanity')) {
         init_git()
-
-        check_pr(env.CHANGE_ID)
-
-        if (env.DETERMINE_DOCKER_IMAGES == 'yes') {
-          sh(
-            script: "./ci/scripts/determine_docker_images.py ci_arm=${ci_arm} ci_cortexm=${ci_cortexm} ci_cpu=${ci_cpu} ci_gpu=${ci_gpu} ci_hexagon=${ci_hexagon} ci_i386=${ci_i386} ci_lint=${ci_lint} ci_minimal=${ci_minimal} ci_riscv=${ci_riscv} ci_wasm=${ci_wasm} ",
-            label: 'Decide whether to use tlcpack or tlcpackstaging for Docker images',
-          )
-          // Pull image names from the results of should_rebuild_docker.py
-          ci_arm = sh(
-            script: "cat .docker-image-names/ci_arm",
-            label: "Find docker image name for ci_arm",
-            returnStdout: true,
-          ).trim()
-          ci_cortexm = sh(
-            script: "cat .docker-image-names/ci_cortexm",
-            label: "Find docker image name for ci_cortexm",
-            returnStdout: true,
-          ).trim()
-          ci_cpu = sh(
-            script: "cat .docker-image-names/ci_cpu",
-            label: "Find docker image name for ci_cpu",
-            returnStdout: true,
-          ).trim()
-          ci_gpu = sh(
-            script: "cat .docker-image-names/ci_gpu",
-            label: "Find docker image name for ci_gpu",
-            returnStdout: true,
-          ).trim()
-          ci_hexagon = sh(
-            script: "cat .docker-image-names/ci_hexagon",
-            label: "Find docker image name for ci_hexagon",
-            returnStdout: true,
-          ).trim()
-          ci_i386 = sh(
-            script: "cat .docker-image-names/ci_i386",
-            label: "Find docker image name for ci_i386",
-            returnStdout: true,
-          ).trim()
-          ci_lint = sh(
-            script: "cat .docker-image-names/ci_lint",
-            label: "Find docker image name for ci_lint",
-            returnStdout: true,
-          ).trim()
-          ci_minimal = sh(
-            script: "cat .docker-image-names/ci_minimal",
-            label: "Find docker image name for ci_minimal",
-            returnStdout: true,
-          ).trim()
-          ci_riscv = sh(
-            script: "cat .docker-image-names/ci_riscv",
-            label: "Find docker image name for ci_riscv",
-            returnStdout: true,
-          ).trim()
-          ci_wasm = sh(
-            script: "cat .docker-image-names/ci_wasm",
-            label: "Find docker image name for ci_wasm",
-            returnStdout: true,
-          ).trim()
-        }
-
-        ci_arm = params.ci_arm_param ?: ci_arm
-        ci_cortexm = params.ci_cortexm_param ?: ci_cortexm
-        ci_cpu = params.ci_cpu_param ?: ci_cpu
-        ci_gpu = params.ci_gpu_param ?: ci_gpu
-        ci_hexagon = params.ci_hexagon_param ?: ci_hexagon
-        ci_i386 = params.ci_i386_param ?: ci_i386
-        ci_lint = params.ci_lint_param ?: ci_lint
-        ci_minimal = params.ci_minimal_param ?: ci_minimal
-        ci_riscv = params.ci_riscv_param ?: ci_riscv
-        ci_wasm = params.ci_wasm_param ?: ci_wasm
-
-        sh (script: """
-          echo "Docker images being used in this build:"
-          echo " ci_arm = ${ci_arm}"
-          echo " ci_cortexm = ${ci_cortexm}"
-          echo " ci_cpu = ${ci_cpu}"
-          echo " ci_gpu = ${ci_gpu}"
-          echo " ci_hexagon = ${ci_hexagon}"
-          echo " ci_i386 = ${ci_i386}"
-          echo " ci_lint = ${ci_lint}"
-          echo " ci_minimal = ${ci_minimal}"
-          echo " ci_riscv = ${ci_riscv}"
-          echo " ci_wasm = ${ci_wasm}"
-        """, label: 'Docker image names')
-
         is_docs_only_build = sh (
           returnStatus: true,
-          script: './ci/scripts/git_change_docs.sh',
+          script: './tests/scripts/git_change_docs.sh',
           label: 'Check for docs only changes',
         )
         skip_ci = should_skip_ci(env.CHANGE_ID)
         skip_slow_tests = should_skip_slow_tests(env.CHANGE_ID)
-        rebuild_docker_images = sh (
-          returnStatus: true,
-          script: './ci/scripts/git_change_docker.sh',
-          label: 'Check for any docker changes',
+        sh (
+          script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh",
+          label: 'Run lint',
         )
-
-        if (skip_ci) {
-          // Don't rebuild when skipping CI
-          rebuild_docker_images = false
-        }
       }
     }
   }
 }
-
-def ecr_push(full_name) {
-  aws_account_id = sh(
-    returnStdout: true,
-    script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"',
-    label: 'Get AWS ID'
-  ).trim()
+}
 
-  def ecr_name = "${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com/${full_name}"
-  try {
-    withEnv([
-      "AWS_ACCOUNT_ID=${aws_account_id}",
-      'AWS_DEFAULT_REGION=us-west-2',
-      "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) {
-      sh(
-        script: '''
-          set -eux
-          aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO
-        ''',
-        label: 'Log in to ECR'
-      )
-      sh(
-        script: """
-          set -x
-          . ci/scripts/retry.sh
-          docker tag ${full_name} \$AWS_ECR_REPO/${full_name}
-          retry 5 docker push \$AWS_ECR_REPO/${full_name}
-        """,
-        label: 'Upload image to ECR'
-      )
-    }
-  } finally {
-    withEnv([
-      "AWS_ACCOUNT_ID=${aws_account_id}",
-      'AWS_DEFAULT_REGION=us-west-2',
-      "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) {
-      sh(
-        script: 'docker logout $AWS_ECR_REPO',
-        label: 'Clean up login credentials'
+// [note: method size]
+// This has to be extracted into a method due to JVM limitations on the size of
+// a method (so the code can't all be inlined)
+lint()
+
+// Run make. First try to do an incremental make from a previous workspace in hope to
+// accelerate the compilation. If something is wrong, clean the workspace and then
+// build from scratch.
+def make(docker_type, path, make_flag) {
+  timeout(time: max_time, unit: 'MINUTES') {
+    try {
+      cmake_build(docker_type, path, make_flag)
+      // always run cpp test when build
+      // sh "${docker_run} ${docker_type} ./tests/scripts/task_cpp_unittest.sh"
+    } catch (hudson.AbortException ae) {
+      // script exited due to user abort, directly throw instead of retry
+      if (ae.getMessage().contains('script returned exit code 143')) {
+        throw ae
+      }
+      echo 'Incremental compilation failed. Fall back to build from scratch'
+      sh (
+        script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}",
+        label: 'Clear old cmake workspace',
       )
+      cmake_build(docker_type, path, make_flag)
+      cpp_unittest(docker_type)
     }
   }
-  return ecr_name
 }
 
-def ecr_pull(full_name) {
-  aws_account_id = sh(
-    returnStdout: true,
-    script: 'aws sts get-caller-identity | grep Account | cut -f4 -d\\"',
-    label: 'Get AWS ID'
-  ).trim()
+// Specifications to Jenkins "stash" command for use with various pack_ and unpack_ functions.
+tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' // use libtvm_runtime.so.
+tvm_lib = 'build/libtvm.so, ' + tvm_runtime // use libtvm.so to run the full compiler.
+// LLVM upstream lib
+tvm_multilib = 'build/libtvm.so, ' +
+               'build/libvta_fsim.so, ' +
+               tvm_runtime
 
-  try {
-    withEnv([
-      "AWS_ACCOUNT_ID=${aws_account_id}",
-      'AWS_DEFAULT_REGION=us-west-2',
-      "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) {
-      sh(
-        script: '''
-          set -eux
-          aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ECR_REPO
-        ''',
-        label: 'Log in to ECR'
-      )
-      sh(
-        script: """
-          set -eux
-          . ci/scripts/retry.sh
-          retry 5 docker pull ${full_name}
-        """,
-        label: 'Pull image from ECR'
-      )
-    }
-  } finally {
-    withEnv([
-      "AWS_ACCOUNT_ID=${aws_account_id}",
-      'AWS_DEFAULT_REGION=us-west-2',
-      "AWS_ECR_REPO=${aws_account_id}.dkr.ecr.us-west-2.amazonaws.com"]) {
-      sh(
-        script: 'docker logout $AWS_ECR_REPO',
-        label: 'Clean up login credentials'
-      )
-    }
-  }
+tvm_multilib_tsim = 'build/libvta_tsim.so, ' +
+                    tvm_multilib
+
+microtvm_tar_gz = 'build/microtvm_template_projects.tar.gz'
+
+// pack libraries for later use
+def pack_lib(name, libs) {
+  sh (script: """
+    echo "Packing ${libs} into ${name}"
+    echo ${libs} | sed -e 's/,/ /g' | xargs md5sum
+  """, label: 'Stash libraries and show md5')
+  stash includes: libs, name: name
+}
+
+// unpack libraries saved before
+def unpack_lib(name, libs) {
+  unstash name
+  sh (script: """
+    echo "Unpacked ${libs} from ${name}"
+    echo ${libs} | sed -e 's/,/ /g' | xargs md5sum
+  """, label: 'Unstash libraries and show md5')
 }
 
-def build_image(image_name) {
-  hash = sh(
-    returnStdout: true,
-    script: 'git log -1 --format=\'%h\''
-  ).trim()
-  def full_name = "${image_name}:${env.BRANCH_NAME}-${hash}-${env.BUILD_NUMBER}"
+// compress microtvm template projects and pack the tar.
+def pack_microtvm_template_projects(name) {
   sh(
-    script: "${docker_build} ${image_name} --spec ${full_name}",
-    label: 'Build docker image'
+    script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/',
+    label: 'Compress microtvm_template_projects'
   )
-  return ecr_push(full_name)
+  pack_lib(name + '-microtvm-libs', microtvm_tar_gz)
 }
 
-def build_docker_images() {
-  stage('Docker Image Build') {
-    parallel(
-      'ci_arm': {
-        node('ARM') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_arm = build_image('ci_arm')
-            built_ci_arm = build_image('ci_arm');
-          }
-        }
-      },
-      'ci_cortexm': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_cortexm = build_image('ci_cortexm')
-            built_ci_cortexm = build_image('ci_cortexm');
-          }
-        }
-      },
-      'ci_cpu': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_cpu = build_image('ci_cpu')
-            built_ci_cpu = build_image('ci_cpu');
-          }
-        }
-      },
-      'ci_gpu': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_gpu = build_image('ci_gpu')
-            built_ci_gpu = build_image('ci_gpu');
-          }
-        }
-      },
-      'ci_hexagon': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_hexagon = build_image('ci_hexagon')
-            built_ci_hexagon = build_image('ci_hexagon');
-          }
-        }
-      },
-      'ci_i386': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_i386 = build_image('ci_i386')
-            built_ci_i386 = build_image('ci_i386');
-          }
-        }
-      },
-      'ci_lint': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_lint = build_image('ci_lint')
-            built_ci_lint = build_image('ci_lint');
-          }
-        }
-      },
-      'ci_minimal': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_minimal = build_image('ci_minimal')
-            built_ci_minimal = build_image('ci_minimal');
-          }
-        }
-      },
-      'ci_riscv': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_riscv = build_image('ci_riscv')
-            built_ci_riscv = build_image('ci_riscv');
-          }
-        }
-      },
-      'ci_wasm': {
-        node('CPU') {
-          timeout(time: max_time, unit: 'MINUTES') {
-            init_git()
-            // We're purposefully not setting the built image here since they
-            // are not yet being uploaded to tlcpack
-            // ci_wasm = build_image('ci_wasm')
-            built_ci_wasm = build_image('ci_wasm');
-          }
-        }
-      },
-    )
-  }
-}
-def lint() {
-  stage('Lint') {
-    parallel(
-      'Lint 1 of 2': {
-        node('CPU-SMALL') {
-          ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/lint") {
-            init_git()
-            docker_init(ci_lint)
-            timeout(time: max_time, unit: 'MINUTES') {
-              withEnv([
-                'TVM_NUM_SHARDS=2',
-                'TEST_STEP_NAME=Lint',
-                'TVM_SHARD_INDEX=0',
-                "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-                sh (
-                  script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh",
-                  label: 'Run lint',
-                )
-              })
-            }
-          }
-        }
-      },
-      'Lint 2 of 2': {
-        node('CPU-SMALL') {
-          ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/lint") {
-            init_git()
-            docker_init(ci_lint)
-            timeout(time: max_time, unit: 'MINUTES') {
-              withEnv([
-                'TVM_NUM_SHARDS=2',
-                'TEST_STEP_NAME=Lint',
-                'TVM_SHARD_INDEX=1',
-                "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-                sh (
-                  script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh",
-                  label: 'Run lint',
-                )
-              })
-            }
-          }
-        }
-      },
-    )
-  }
+def unpack_microtvm_template_projects(name) {
+  unpack_lib(name + '-microtvm-libs', microtvm_tar_gz)
+  sh(
+    script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz',
+    label: 'Unpack microtvm_template_projects'
+  )
 }
+
 def ci_setup(image) {
   sh (
     script: "${docker_run} ${image} ./tests/scripts/task_ci_setup.sh",
@@ -676,25 +305,18 @@ def fsim_test(image) {
 
 def cmake_build(image, path, make_flag) {
   sh (
-    script: "${docker_run} --env CI_NUM_EXECUTORS ${image} ./tests/scripts/task_build.py --sccache-bucket tvm-sccache-prod",
+    script: "${docker_run} ${image} ./tests/scripts/task_build.py --sccache-bucket tvm-sccache-prod",
     label: 'Run cmake build',
   )
 }
 
 def cpp_unittest(image) {
   sh (
-    script: "${docker_run} --env CI_NUM_EXECUTORS ${image} ./tests/scripts/task_cpp_unittest.sh",
+    script: "${docker_run} ${image} ./tests/scripts/task_cpp_unittest.sh",
    label: 'Build and run C++ tests',
   )
 }
 
-def add_microtvm_permissions() {
-  sh(
-    script: 'find build/microtvm_template_projects -type f | grep qemu-hack | xargs chmod +x',
-    label: 'Add execute permissions for microTVM files',
-  )
-}
-
 def add_hexagon_permissions() {
   sh(
     script: 'find build/hexagon_api_output -type f | xargs chmod +x',
@@ -702,304 +324,34 @@ def add_hexagon_permissions() {
   )
 }
 
-// Run make. First try to do an incremental make from a previous workspace in hope to
-// accelerate the compilation. If something is wrong, clean the workspace and then
-// build from scratch.
-def make(docker_type, path, make_flag) {
-  timeout(time: max_time, unit: 'MINUTES') {
-    try {
-      cmake_build(docker_type, path, make_flag)
-    } catch (hudson.AbortException ae) {
-      // script exited due to user abort, directly throw instead of retry
-      if (ae.getMessage().contains('script returned exit code 143')) {
-        throw ae
-      }
-      echo 'Incremental compilation failed. Fall back to build from scratch'
-      sh (
-        script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}",
-        label: 'Clear old cmake workspace',
-      )
-      cmake_build(docker_type, path, make_flag)
-    }
-  }
-}
-
-
-def build() {
-stage('Build') {
-  environment {
-    SKIP_SLOW_TESTS = "${skip_slow_tests}"
-  }
-  parallel(
-
-  'BUILD: GPU': {
-    if (!skip_ci) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-gpu") {
+stage('Build and Test') {
+  if (is_docs_only_build != 1) {
+    parallel 'BUILD: GPU': {
+      node('GPU') {
+        ws(per_exec_ws('tvm/build-gpu')) {
           init_git()
-          docker_init(ci_gpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build"
-            make("${ci_gpu} --no-gpu", 'build', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/gpu/build/libtvm.so
-                md5sum build/libvta_fsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_fsim.so s3://${s3_prefix}/gpu/build/libvta_fsim.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/gpu/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/gpu/build/config.cmake
-                retry 3 aws s3 cp --no-progress build/microtvm_template_projects s3://${s3_prefix}/gpu/build/microtvm_template_projects --recursive
-              """,
-              label: 'Upload artifacts to S3',
-            )
-
-
-            // compiler test
-            sh "${docker_run} --no-gpu ${ci_gpu} ./tests/scripts/task_config_build_gpu_other.sh build2"
-            make("${ci_gpu} --no-gpu", 'build2', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/gpu2/build/libtvm.so
-                md5sum build/libvta_fsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_fsim.so s3://${s3_prefix}/gpu2/build/libvta_fsim.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/gpu2/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/gpu2/build/config.cmake
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
+          sh "${docker_run} ${ci_gpu} nvidia-smi"
+          sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build"
+          make("${ci_gpu}", 'build', '-j2')
+          sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh"
         }
       }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: GPU')
-    }
-  },
-
-  'BUILD: CPU': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-cpu") {
+    },
+    'BUILD: CPU': {
+      node('CPU') {
+        ws(per_exec_ws('tvm/build-cpu')) {
           init_git()
-          docker_init(ci_cpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build",
-              label: 'Create CPU cmake config',
-            )
+          sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build"
           make(ci_cpu, 'build', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libvta_tsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_tsim.so s3://${s3_prefix}/cpu/build/libvta_tsim.so
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/cpu/build/libtvm.so
-                md5sum build/libvta_fsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_fsim.so s3://${s3_prefix}/cpu/build/libvta_fsim.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/cpu/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/cpu/build/config.cmake
-              """,
-              label: 'Upload artifacts to S3',
-            )
-
-            ci_setup(ci_cpu)
-            // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh"
-            // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch
-            sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: 'Rust build and test')
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: CPU')
-    }
-  },
-
-  'BUILD: CPU MINIMAL': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-cpu-minimal") {
-          init_git()
-          docker_init(ci_minimal)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_minimal} ./tests/scripts/task_config_build_minimal.sh build",
-              label: 'Create CPU minimal cmake config',
-            )
-            make(ci_minimal, 'build', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/cpu-minimal/build/libtvm.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/cpu-minimal/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/cpu-minimal/build/config.cmake
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: CPU MINIMAL')
-    }
-  },
-
-  'BUILD: WASM': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-wasm") {
-          init_git()
-          docker_init(ci_wasm)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_wasm} ./tests/scripts/task_config_build_wasm.sh build",
-              label: 'Create WASM cmake config',
-            )
-            make(ci_wasm, 'build', '-j2')
-            cpp_unittest(ci_wasm)
-            ci_setup(ci_wasm)
-            sh (
-              script: "${docker_run} ${ci_wasm} ./tests/scripts/task_web_wasm.sh",
-              label: 'Run WASM lint and tests',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: WASM')
-    }
-  },
-
-  'BUILD: i386': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-i386") {
-          init_git()
-          docker_init(ci_i386)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_i386} ./tests/scripts/task_config_build_i386.sh build",
-              label: 'Create i386 cmake config',
-            )
-            make(ci_i386, 'build', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libvta_tsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_tsim.so s3://${s3_prefix}/i386/build/libvta_tsim.so
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/i386/build/libtvm.so
-                md5sum build/libvta_fsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_fsim.so s3://${s3_prefix}/i386/build/libvta_fsim.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/i386/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/i386/build/config.cmake
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: i386')
-    }
-  },
-
-  'BUILD: arm': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('ARM-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-arm") {
-          init_git()
-          docker_init(ci_arm)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_arm} ./tests/scripts/task_config_build_arm.sh build",
-              label: 'Create ARM cmake config',
-            )
-            make(ci_arm, 'build', '-j4')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/arm/build/libtvm.so
-                md5sum build/libvta_fsim.so
-                retry 3 aws s3 cp --no-progress build/libvta_fsim.so s3://${s3_prefix}/arm/build/libvta_fsim.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/arm/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/arm/build/config.cmake
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: arm')
-    }
-  },
-
-  'BUILD: Cortex-M': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-cortexm") {
-          init_git()
-          docker_init(ci_cortexm)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_config_build_cortexm.sh build",
-              label: 'Create Cortex-M cmake config',
-            )
-            make(ci_cortexm, 'build', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/cortexm/build/libtvm.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/cortexm/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/cortexm/build/config.cmake
-                retry 3 aws s3 cp --no-progress build/microtvm_template_projects s3://${s3_prefix}/cortexm/build/microtvm_template_projects --recursive
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: Cortex-M')
-    }
-  },
-
-  'BUILD: Hexagon': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-hexagon") {
+          sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh"
+        }
+      }
+    },
+    'BUILD: Hexagon': {
+      node('CPU') {
+        ws(per_exec_ws('tvm/build-hexagon')) {
           init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
+          sh (
             script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_config_build_hexagon.sh build",
             label: 'Create Hexagon cmake config',
           )
@@ -1008,3541 +360,16 @@ stage('Build') {
          make(ci_hexagon, 'build', '-j2')
          sh (
            script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_build_hexagon_api.sh",
            label: 'Build Hexagon API',
          )
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/hexagon/build/libtvm.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/hexagon/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/hexagon/build/config.cmake
-                retry 3 aws s3 cp --no-progress build/hexagon_api_output s3://${s3_prefix}/hexagon/build/hexagon_api_output --recursive
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: Hexagon')
-    }
-  },
-
-  'BUILD: RISC-V': {
-    if (!skip_ci && is_docs_only_build != 1) {
-      node('CPU-SMALL') {
-        ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/build-riscv") {
-          init_git()
-          docker_init(ci_riscv)
-          timeout(time: max_time, unit: 'MINUTES') {
-            sh (
-              script: "${docker_run} ${ci_riscv} ./tests/scripts/task_config_build_riscv.sh build",
-              label: 'Create RISC-V cmake config',
-            )
-            make(ci_riscv, 'build', '-j2')
-            sh(
-              script: """
-                set -eux
-                . ci/scripts/retry.sh
-                md5sum build/libtvm.so
-                retry 3 aws s3 cp --no-progress build/libtvm.so s3://${s3_prefix}/riscv/build/libtvm.so
-                md5sum build/libtvm_runtime.so
-                retry 3 aws s3 cp --no-progress build/libtvm_runtime.so s3://${s3_prefix}/riscv/build/libtvm_runtime.so
-                md5sum build/config.cmake
-                retry 3 aws s3 cp --no-progress build/config.cmake s3://${s3_prefix}/riscv/build/config.cmake
-                retry 3 aws s3 cp --no-progress build/microtvm_template_projects s3://${s3_prefix}/riscv/build/microtvm_template_projects --recursive
-              """,
-              label: 'Upload artifacts to S3',
-            )
-          }
-        }
-      }
-    } else {
-      Utils.markStageSkippedForConditional('BUILD: RISC-V')
-    }
-  },
-
-  )
-}
-}
-
-// We have to do this whacky split of the code from where it's used since the
-// JVM limits method length to 64k and we easily exceed that with all this
-// autogenerated code. This makes it so each test step is in its own method so
-// that each individual method isn't too big.
-
-def shard_run_unittest_GPU_1_of_3() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('GPU') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") {
-        try {
-          init_git()
-          docker_init(ci_gpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=gpu',
-              'TEST_STEP_NAME=unittest: GPU',
-              'TVM_NUM_SHARDS=3',
-              'TVM_SHARD_INDEX=0',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu2/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu2/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu2/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu2/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              cpp_unittest(ci_gpu)
-
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_gpu)
-              cpp_unittest(ci_gpu)
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh",
-                label: 'Run Python GPU unit tests',
-              )
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh",
-                label: 'Run Python GPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/unittest_GPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('unittest: GPU 1 of 3')
-  }
-}
-
-def shard_run_unittest_GPU_2_of_3() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('GPU') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") {
-        try {
-          init_git()
-          docker_init(ci_gpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=gpu',
-              'TEST_STEP_NAME=unittest: GPU',
-              'TVM_NUM_SHARDS=3',
-              'TVM_SHARD_INDEX=1',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_gpu)
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_java_unittest.sh",
-                label: 'Run Java unit tests',
-              )
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh",
-                label: 'Run Python GPU unit tests',
-              )
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh",
-                label: 'Run Python GPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/unittest_GPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('unittest: GPU 2 of 3')
-  }
-}
-
-def shard_run_unittest_GPU_3_of_3() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('GPU') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-gpu") {
-        try {
-          init_git()
-          docker_init(ci_gpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=gpu',
-              'TEST_STEP_NAME=unittest: GPU',
-              'TVM_NUM_SHARDS=3',
-              'TVM_SHARD_INDEX=2',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_gpu)
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest_gpuonly.sh",
-                label: 'Run Python GPU unit tests',
-              )
-              sh (
-                script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration_gpuonly.sh",
-                label: 'Run Python GPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/unittest_GPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
+          add_hexagon_permissions()
+          sh (
+            script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
+            label: 'Run Hexagon tests',
           )
-
-          junit 'build/pytest-results/*.xml'
         }
       }
     }
   } else {
-    Utils.markStageSkippedForConditional('unittest: GPU 3 of 3')
+    Utils.markStageSkippedForConditional('BUILD: CPU')
  }
 }
-
-def shard_run_integration_CPU_1_of_4() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") {
-        try {
-          init_git()
-          docker_init(ci_cpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=cpu',
-              'TEST_STEP_NAME=integration: CPU',
-              'TVM_NUM_SHARDS=4',
-              'TVM_SHARD_INDEX=0',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_tsim.so build/libvta_tsim.so
-                  md5sum build/libvta_tsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_cpu)
-              sh (
-                script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh",
-                label: 'Run CPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_CPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('integration: CPU 1 of 4')
-  }
-}
-
-def shard_run_integration_CPU_2_of_4() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") {
-        try {
-          init_git()
-          docker_init(ci_cpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=cpu',
-              'TEST_STEP_NAME=integration: CPU',
-              'TVM_NUM_SHARDS=4',
-              'TVM_SHARD_INDEX=1',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_tsim.so build/libvta_tsim.so
-                  md5sum build/libvta_tsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_cpu)
-              sh (
-                script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh",
-                label: 'Run CPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_CPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('integration: CPU 2 of 4')
-  }
-}
-
-def shard_run_integration_CPU_3_of_4() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") {
-        try {
-          init_git()
-          docker_init(ci_cpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=cpu',
-              'TEST_STEP_NAME=integration: CPU',
-              'TVM_NUM_SHARDS=4',
-              'TVM_SHARD_INDEX=2',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_tsim.so build/libvta_tsim.so
-                  md5sum build/libvta_tsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_cpu)
-              sh (
-                script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh",
-                label: 'Run CPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_CPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('integration: CPU 3 of 4')
-  }
-}
-
-def shard_run_integration_CPU_4_of_4() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-cpu") {
-        try {
-          init_git()
-          docker_init(ci_cpu)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=cpu',
-              'TEST_STEP_NAME=integration: CPU',
-              'TVM_NUM_SHARDS=4',
-              'TVM_SHARD_INDEX=3',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_tsim.so build/libvta_tsim.so
-                  md5sum build/libvta_tsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_cpu)
-              sh (
-                script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh",
-                label: 'Run CPU integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_CPU --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('integration: CPU 4 of 4')
-  }
-}
-
-
-def shard_run_python_i386_1_of_3() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") {
-        try {
-          init_git()
-          docker_init(ci_i386)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=i386',
-              'TEST_STEP_NAME=python: i386',
-              'TVM_NUM_SHARDS=3',
-              'TVM_SHARD_INDEX=0',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_i386)
-              cpp_unittest(ci_i386)
-              python_unittest(ci_i386)
-              sh (
-                script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh",
-                label: 'Run i386 integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/python_i386 --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('python: i386 1 of 3')
-  }
-}
-
-def shard_run_python_i386_2_of_3() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") {
-        try {
-          init_git()
-          docker_init(ci_i386)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=i386',
-              'TEST_STEP_NAME=python: i386',
-              'TVM_NUM_SHARDS=3',
-              'TVM_SHARD_INDEX=1',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_i386)
-              python_unittest(ci_i386)
-              sh (
-                script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh",
-                label: 'Run i386 integration tests',
-              )
-              fsim_test(ci_i386)
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/python_i386 --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('python: i386 2 of 3')
-  }
-}
-
-def shard_run_python_i386_3_of_3() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/integration-python-i386") {
-        try {
-          init_git()
-          docker_init(ci_i386)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=i386',
-              'TEST_STEP_NAME=python: i386',
-              'TVM_NUM_SHARDS=3',
-              'TVM_SHARD_INDEX=2',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libvta_fsim.so build/libvta_fsim.so
-                  md5sum build/libvta_fsim.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/i386/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              ci_setup(ci_i386)
-              python_unittest(ci_i386)
-              sh (
-                script: "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh",
-                label: 'Run i386 integration tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/python_i386 --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('python: i386 3 of 3')
-  }
-}
-
-
-def shard_run_test_Hexagon_1_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=0',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              cpp_unittest(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('test: Hexagon 1 of 8')
-  }
-}
-
-def shard_run_test_Hexagon_2_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=1',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('test: Hexagon 2 of 8')
-  }
-}
-
-def shard_run_test_Hexagon_3_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=2',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('test: Hexagon 3 of 8')
-  }
-}
-
-def shard_run_test_Hexagon_4_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=3',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('test: Hexagon 4 of 8')
-  }
-}
-
-def shard_run_test_Hexagon_5_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=4',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('test: Hexagon 5 of 8')
-  }
-}
-
-def shard_run_test_Hexagon_6_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=5',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              . ci/scripts/retry.sh
-              retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive
-            """,
-            label: 'Upload JUnits to S3',
-          )
-
-          junit 'build/pytest-results/*.xml'
-        }
-      }
-    }
-  } else {
-    Utils.markStageSkippedForConditional('test: Hexagon 6 of 8')
-  }
-}
-
-def shard_run_test_Hexagon_7_of_8() {
-  if (!skip_ci && is_docs_only_build != 1) {
-    node('CPU-SMALL') {
-      ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") {
-        try {
-          init_git()
-          docker_init(ci_hexagon)
-          timeout(time: max_time, unit: 'MINUTES') {
-            withEnv([
-              'PLATFORM=hexagon',
-              'TEST_STEP_NAME=test: Hexagon',
-              'TVM_NUM_SHARDS=8',
-              'TVM_SHARD_INDEX=6',
-              "SKIP_SLOW_TESTS=${skip_slow_tests}"], {
-              sh(
-                script: """
-                  set -eux
-                  . ci/scripts/retry.sh
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so
-                  md5sum build/libtvm.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so
-                  md5sum build/libtvm_runtime.so
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake
-                  md5sum build/config.cmake
-                  retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive
-                """,
-                label: 'Download artifacts from S3',
-              )
-
-              add_hexagon_permissions()
-              ci_setup(ci_hexagon)
-              sh (
-                script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh",
-                label: 'Run Hexagon tests',
-              )
-            })
-          }
-        } finally {
-          sh(
-            script: """
-              set -eux
-              .
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Hexagon 7 of 8') - } -} - -def shard_run_test_Hexagon_8_of_8() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-hexagon") { - try { - init_git() - docker_init(ci_hexagon) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=hexagon', - 'TEST_STEP_NAME=test: Hexagon', - 'TVM_NUM_SHARDS=8', - 'TVM_SHARD_INDEX=7', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/hexagon/build/hexagon_api_output build/hexagon_api_output --recursive - """, - label: 'Download artifacts from S3', - ) - - add_hexagon_permissions() - ci_setup(ci_hexagon) - sh ( - script: "${docker_run} ${ci_hexagon} ./tests/scripts/task_python_hexagon.sh", - label: 'Run Hexagon tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Hexagon --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Hexagon 8 of 8') - } -} - - -def shard_run_integration_aarch64_1_of_4() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=integration: aarch64', - 'TVM_NUM_SHARDS=4', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - python_unittest(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh", - label: 'Run CPU integration tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
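The eight Hexagon shards above differ only in TVM_SHARD_INDEX and the stage label; the task scripts read TVM_NUM_SHARDS and TVM_SHARD_INDEX to select their slice of the test suite, so eight otherwise identical functions fan the suite out across eight executors. The same parallel branches could in principle be built in a loop; a minimal sketch, assuming the closure bodies stay small enough for Jenkins and using make_shards as an illustrative name:

    def make_shards(String name, int num_shards, Closure body) {
        def branches = [:]
        for (int i = 0; i < num_shards; i++) {
            def shard = i  // per-iteration copy so each closure sees its own index
            branches["${name} ${shard + 1} of ${num_shards}"] = {
                withEnv(["TVM_NUM_SHARDS=${num_shards}", "TVM_SHARD_INDEX=${shard}"]) {
                    body()
                }
            }
        }
        return branches
    }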
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('integration: aarch64 1 of 4') - } -} - -def shard_run_integration_aarch64_2_of_4() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=integration: aarch64', - 'TVM_NUM_SHARDS=4', - 'TVM_SHARD_INDEX=1', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - python_unittest(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh", - label: 'Run CPU integration tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('integration: aarch64 2 of 4') - } -} - -def shard_run_integration_aarch64_3_of_4() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=integration: aarch64', - 'TVM_NUM_SHARDS=4', - 'TVM_SHARD_INDEX=2', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - python_unittest(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh", - label: 'Run CPU integration tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('integration: aarch64 3 of 4') - } -} - -def shard_run_integration_aarch64_4_of_4() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=integration: aarch64', - 'TVM_NUM_SHARDS=4', - 'TVM_SHARD_INDEX=3', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - python_unittest(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh", - label: 'Run CPU integration tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/integration_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('integration: aarch64 4 of 4') - } -} - - -def shard_run_topi_GPU_1_of_3() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=topi: GPU', - 'TVM_NUM_SHARDS=3', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_topi.sh", - label: 'Run TOPI tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/topi_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('topi: GPU 1 of 3') - } -} - -def shard_run_topi_GPU_2_of_3() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=topi: GPU', - 'TVM_NUM_SHARDS=3', - 'TVM_SHARD_INDEX=1', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_topi.sh", - label: 'Run TOPI tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/topi_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('topi: GPU 2 of 3') - } -} - -def shard_run_topi_GPU_3_of_3() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/topi-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=topi: GPU', - 'TVM_NUM_SHARDS=3', - 'TVM_SHARD_INDEX=2', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_topi.sh", - label: 'Run TOPI tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/topi_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('topi: GPU 3 of 3') - } -} - - -def shard_run_frontend_GPU_1_of_6() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=frontend: GPU', - 'TVM_NUM_SHARDS=6', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: GPU 1 of 6') - } -} - -def shard_run_frontend_GPU_2_of_6() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=frontend: GPU', - 'TVM_NUM_SHARDS=6', - 'TVM_SHARD_INDEX=1', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: GPU 2 of 6') - } -} - -def shard_run_frontend_GPU_3_of_6() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=frontend: GPU', - 'TVM_NUM_SHARDS=6', - 'TVM_SHARD_INDEX=2', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: GPU 3 of 6') - } -} - -def shard_run_frontend_GPU_4_of_6() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=frontend: GPU', - 'TVM_NUM_SHARDS=6', - 'TVM_SHARD_INDEX=3', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: GPU 4 of 6') - } -} - -def shard_run_frontend_GPU_5_of_6() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=frontend: GPU', - 'TVM_NUM_SHARDS=6', - 'TVM_SHARD_INDEX=4', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: GPU 5 of 6') - } -} - -def shard_run_frontend_GPU_6_of_6() { - if (!skip_ci && is_docs_only_build != 1) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-gpu") { - try { - init_git() - docker_init(ci_gpu) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=gpu', - 'TEST_STEP_NAME=frontend: GPU', - 'TVM_NUM_SHARDS=6', - 'TVM_SHARD_INDEX=5', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_GPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: GPU 6 of 6') - } -} - - -def shard_run_topi_aarch64_1_of_2() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=topi: aarch64', - 'TVM_NUM_SHARDS=2', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - cpp_unittest(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh", - label: 'Run test_arm_compute_lib test', - ) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_topi.sh", - label: 'Run TOPI tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/topi_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('topi: aarch64 1 of 2') - } -} - -def shard_run_topi_aarch64_2_of_2() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=topi: aarch64', - 'TVM_NUM_SHARDS=2', - 'TVM_SHARD_INDEX=1', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh", - label: 'Run test_arm_compute_lib test', - ) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_topi.sh", - label: 'Run TOPI tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/topi_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('topi: aarch64 2 of 2') - } -} - - -def shard_run_frontend_aarch64_1_of_2() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=frontend: aarch64', - 'TVM_NUM_SHARDS=2', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: aarch64 1 of 2') - } -} - -def shard_run_frontend_aarch64_2_of_2() { - if (!skip_ci && is_docs_only_build != 1) { - node('ARM-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-arm") { - try { - init_git() - docker_init(ci_arm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=arm', - 'TEST_STEP_NAME=frontend: aarch64', - 'TVM_NUM_SHARDS=2', - 'TVM_SHARD_INDEX=1', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/arm/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_arm) - sh ( - script: "${docker_run} ${ci_arm} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_aarch64 --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: aarch64 2 of 2') - } -} - - -def shard_run_test_Cortex_M_1_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - cpp_unittest(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_demo_microtvm.sh", - label: 'Run microTVM demos', - ) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 1 of 12') - } -} - -def shard_run_test_Cortex_M_2_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=1', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
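Each shard also closes with the same finally block: copy build/pytest-results to a per-step S3 prefix, then publish the XML with junit, so test results are recorded even when the stage itself fails. That recurring tail could likewise be factored out; with_junit_upload below is a hypothetical name, not part of the file:

    def with_junit_upload(String results_dir, Closure body) {
        try {
            body()
        } finally {
            sh(
                script: """
                    set -eux
                    . ci/scripts/retry.sh
                    retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/${results_dir} --recursive
                """,
                label: 'Upload JUnits to S3',
            )
            // Runs on success and failure alike, so partial results still surface.
            junit 'build/pytest-results/*.xml'
        }
    }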
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 2 of 12') - } -} - -def shard_run_test_Cortex_M_3_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=2', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 3 of 12') - } -} - -def shard_run_test_Cortex_M_4_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=3', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 4 of 12') - } -} - -def shard_run_test_Cortex_M_5_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=4', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 5 of 12') - } -} - -def shard_run_test_Cortex_M_6_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=5', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 6 of 12') - } -} - -def shard_run_test_Cortex_M_7_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=6', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 7 of 12') - } -} - -def shard_run_test_Cortex_M_8_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=7', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
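The Cortex-M shards extend the generic preamble in only two ways: they also pull build/microtvm_template_projects from S3 and call add_microtvm_permissions() before running task_python_microtvm.sh. With the hypothetical download_artifacts helper sketched earlier, one shard body would reduce to roughly:

    download_artifacts('cortexm', [
        'build/libtvm.so',
        'build/libtvm_runtime.so',
        'build/config.cmake',
    ])
    sh(
        script: """
            set -eux
            . ci/scripts/retry.sh
            retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive
        """,
        label: 'Download microTVM template projects',  // label wording is illustrative
    )
    add_microtvm_permissions()
    ci_setup(ci_cortexm)
    sh(
        script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh",
        label: 'Run microTVM tests',
    )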
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 8 of 12') - } -} - -def shard_run_test_Cortex_M_9_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=8', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 9 of 12') - } -} - -def shard_run_test_Cortex_M_10_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=9', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 10 of 12') - } -} - -def shard_run_test_Cortex_M_11_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=10', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 11 of 12') - } -} - -def shard_run_test_Cortex_M_12_of_12() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-cortexm") { - try { - init_git() - docker_init(ci_cortexm) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=cortexm', - 'TEST_STEP_NAME=test: Cortex-M', - 'TVM_NUM_SHARDS=12', - 'TVM_SHARD_INDEX=11', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cortexm/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_cortexm) - sh ( - script: "${docker_run} ${ci_cortexm} ./tests/scripts/task_python_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_Cortex_M --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: Cortex-M 12 of 12') - } -} - - -def shard_run_test_RISC_V_1_of_1() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/test-riscv") { - try { - init_git() - docker_init(ci_riscv) - timeout(time: max_time, unit: 'MINUTES') { - withEnv([ - 'PLATFORM=riscv', - 'TEST_STEP_NAME=test: RISC-V', - 'TVM_NUM_SHARDS=1', - 'TVM_SHARD_INDEX=0', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/riscv/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/riscv/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/riscv/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/riscv/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - ci_setup(ci_riscv) - cpp_unittest(ci_cortexm) - sh ( - script: "${docker_run} ${ci_riscv} ./tests/scripts/task_riscv_microtvm.sh", - label: 'Run microTVM tests', - ) - }) - } - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/test_RISC_V --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } else { - Utils.markStageSkippedForConditional('test: RISC-V 1 of 1') - } -} - - -def run_unittest_minimal() { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-cpu-minimal") { - timeout(time: max_time, unit: 'MINUTES') { - try { - init_git() - docker_init(ci_minimal) - withEnv(['PLATFORM=minimal'], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu-minimal/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu-minimal/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu-minimal/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - cpp_unittest(ci_minimal) - python_unittest(ci_minimal) - }) - } finally { - sh( - script: """ - set -eux - . 
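One inconsistency in the removed code is visible above: shard_run_test_RISC_V_1_of_1 runs ci_setup(ci_riscv) and the RISC-V task script, but its C++ unit-test step is cpp_unittest(ci_cortexm). Every other shard pairs cpp_unittest with its own image, so this reads as a copy-paste slip in the generated function; the intended call was presumably:

    cpp_unittest(ci_riscv)  // assumed fix; the diff itself records ci_cortexm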
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/unittest_CPU_MINIMAL --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } - } else { - Utils.markStageSkippedForConditional('unittest: CPU MINIMAL') - } -} - -def test() { -stage('Test') { - environment { - SKIP_SLOW_TESTS = "${skip_slow_tests}" - } - parallel( - 'unittest: GPU 1 of 3': { - shard_run_unittest_GPU_1_of_3() - }, - 'unittest: GPU 2 of 3': { - shard_run_unittest_GPU_2_of_3() - }, - 'unittest: GPU 3 of 3': { - shard_run_unittest_GPU_3_of_3() - }, - 'integration: CPU 1 of 4': { - shard_run_integration_CPU_1_of_4() - }, - 'integration: CPU 2 of 4': { - shard_run_integration_CPU_2_of_4() - }, - 'integration: CPU 3 of 4': { - shard_run_integration_CPU_3_of_4() - }, - 'integration: CPU 4 of 4': { - shard_run_integration_CPU_4_of_4() - }, - 'python: i386 1 of 3': { - shard_run_python_i386_1_of_3() - }, - 'python: i386 2 of 3': { - shard_run_python_i386_2_of_3() - }, - 'python: i386 3 of 3': { - shard_run_python_i386_3_of_3() - }, - 'test: Hexagon 1 of 8': { - shard_run_test_Hexagon_1_of_8() - }, - 'test: Hexagon 2 of 8': { - shard_run_test_Hexagon_2_of_8() - }, - 'test: Hexagon 3 of 8': { - shard_run_test_Hexagon_3_of_8() - }, - 'test: Hexagon 4 of 8': { - shard_run_test_Hexagon_4_of_8() - }, - 'test: Hexagon 5 of 8': { - shard_run_test_Hexagon_5_of_8() - }, - 'test: Hexagon 6 of 8': { - shard_run_test_Hexagon_6_of_8() - }, - 'test: Hexagon 7 of 8': { - shard_run_test_Hexagon_7_of_8() - }, - 'test: Hexagon 8 of 8': { - shard_run_test_Hexagon_8_of_8() - }, - 'integration: aarch64 1 of 4': { - shard_run_integration_aarch64_1_of_4() - }, - 'integration: aarch64 2 of 4': { - shard_run_integration_aarch64_2_of_4() - }, - 'integration: aarch64 3 of 4': { - shard_run_integration_aarch64_3_of_4() - }, - 'integration: aarch64 4 of 4': { - shard_run_integration_aarch64_4_of_4() - }, - 'topi: GPU 1 of 3': { - shard_run_topi_GPU_1_of_3() - }, - 'topi: GPU 2 of 3': { - shard_run_topi_GPU_2_of_3() - }, - 'topi: GPU 3 of 3': { - shard_run_topi_GPU_3_of_3() - }, - 'frontend: GPU 1 of 6': { - shard_run_frontend_GPU_1_of_6() - }, - 'frontend: GPU 2 of 6': { - shard_run_frontend_GPU_2_of_6() - }, - 'frontend: GPU 3 of 6': { - shard_run_frontend_GPU_3_of_6() - }, - 'frontend: GPU 4 of 6': { - shard_run_frontend_GPU_4_of_6() - }, - 'frontend: GPU 5 of 6': { - shard_run_frontend_GPU_5_of_6() - }, - 'frontend: GPU 6 of 6': { - shard_run_frontend_GPU_6_of_6() - }, - 'topi: aarch64 1 of 2': { - shard_run_topi_aarch64_1_of_2() - }, - 'topi: aarch64 2 of 2': { - shard_run_topi_aarch64_2_of_2() - }, - 'frontend: aarch64 1 of 2': { - shard_run_frontend_aarch64_1_of_2() - }, - 'frontend: aarch64 2 of 2': { - shard_run_frontend_aarch64_2_of_2() - }, - 'test: Cortex-M 1 of 12': { - shard_run_test_Cortex_M_1_of_12() - }, - 'test: Cortex-M 2 of 12': { - shard_run_test_Cortex_M_2_of_12() - }, - 'test: Cortex-M 3 of 12': { - shard_run_test_Cortex_M_3_of_12() - }, - 'test: Cortex-M 4 of 12': { - shard_run_test_Cortex_M_4_of_12() - }, - 'test: Cortex-M 5 of 12': { - shard_run_test_Cortex_M_5_of_12() - }, - 'test: Cortex-M 6 of 12': { - shard_run_test_Cortex_M_6_of_12() - }, - 'test: Cortex-M 7 of 12': { - shard_run_test_Cortex_M_7_of_12() - }, - 'test: Cortex-M 8 of 12': { - shard_run_test_Cortex_M_8_of_12() - }, - 'test: Cortex-M 9 of 12': { - shard_run_test_Cortex_M_9_of_12() - }, - 'test: Cortex-M 10 of 12': { - shard_run_test_Cortex_M_10_of_12() - }, - 
'test: Cortex-M 11 of 12': { - shard_run_test_Cortex_M_11_of_12() - }, - 'test: Cortex-M 12 of 12': { - shard_run_test_Cortex_M_12_of_12() - }, - 'test: RISC-V 1 of 1': { - shard_run_test_RISC_V_1_of_1() - }, - 'unittest: CPU MINIMAL': { - run_unittest_minimal() - }, - 'unittest: CPU': { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/ut-python-cpu") { - timeout(time: max_time, unit: 'MINUTES') { - try { - init_git() - docker_init(ci_cpu) - withEnv(['PLATFORM=cpu', - 'TEST_STEP_NAME=unittest: CPU', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_tsim.so build/libvta_tsim.so - md5sum build/libvta_tsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_cpu) - cpp_unittest(ci_cpu) - python_unittest(ci_cpu) - fsim_test(ci_cpu) - sh ( - script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh", - label: 'Run VTA tests in TSIM', - ) - }) - } finally { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/unittest_CPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } - } else { - Utils.markStageSkippedForConditional('unittest: CPU') - } - }, - 'frontend: CPU': { - if (!skip_ci && is_docs_only_build != 1) { - node('CPU-SMALL') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/frontend-python-cpu") { - timeout(time: max_time, unit: 'MINUTES') { - try { - init_git() - docker_init(ci_cpu) - withEnv(['PLATFORM=cpu', - 'TEST_STEP_NAME=frontend: CPU', - "SKIP_SLOW_TESTS=${skip_slow_tests}"], { - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/cpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - """, - label: 'Download artifacts from S3', - ) - - ci_setup(ci_cpu) - sh ( - script: "${docker_run} ${ci_cpu} ./tests/scripts/task_python_frontend_cpu.sh", - label: 'Run Python frontend tests', - ) - }) - } finally { - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress build/pytest-results s3://${s3_prefix}/pytest-results/frontend_CPU --recursive - """, - label: 'Upload JUnits to S3', - ) - - junit 'build/pytest-results/*.xml' - } - } - } - } - } else { - Utils.markStageSkippedForConditional('frontend: CPU') - } - }, - 'docs: GPU': { - if (!skip_ci) { - node('GPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/docs-python-gpu") { - init_git() - docker_init(ci_gpu) - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm.so build/libtvm.so - md5sum build/libtvm.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libvta_fsim.so build/libvta_fsim.so - md5sum build/libvta_fsim.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/libtvm_runtime.so build/libtvm_runtime.so - md5sum build/libtvm_runtime.so - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/config.cmake build/config.cmake - md5sum build/config.cmake - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/gpu/build/microtvm_template_projects build/microtvm_template_projects --recursive - """, - label: 'Download artifacts from S3', - ) - - add_microtvm_permissions() - timeout(time: 180, unit: 'MINUTES') { - ci_setup(ci_gpu) - sh ( - script: "${docker_run} ${ci_gpu} ./tests/scripts/task_python_docs.sh", - label: 'Build docs', - ) - } - sh( - script: """ - set -eux - . ci/scripts/retry.sh - md5sum docs.tgz - retry 3 aws s3 cp --no-progress docs.tgz s3://${s3_prefix}/docs/docs.tgz - """, - label: 'Upload artifacts to S3', - ) - - sh( - script: "aws s3 cp --no-progress _docs s3://${s3_prefix}/docs --recursive", - label: 'Upload docs to S3', - ) - } - } - } - }, - ) -} -} -/* -stage('Build packages') { - parallel 'conda CPU': { - node('CPU') { - sh "${docker_run} tlcpack/conda-cpu ./conda/build_cpu.sh - } - }, - 'conda cuda': { - node('CPU') { - sh "${docker_run} tlcpack/conda-cuda90 ./conda/build_cuda.sh - sh "${docker_run} tlcpack/conda-cuda100 ./conda/build_cuda.sh - } - } -// Here we could upload the packages to anaconda for releases -// and/or the main branch -} -*/ - - -def update_docker(ecr_image, hub_image) { - if (ecr_image == null) { - sh("image was not rebuilt, skipping") - return - } - if (!ecr_image.contains("amazonaws.com")) { - sh("echo \"Skipping '${ecr_image}' -> '${hub_image}' since it doesn\'t look like an ECR image\"") - return - } - docker_init(ecr_image) - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker tag \ - ${ecr_image} \ - ${hub_image} - retry 5 docker push ${hub_image} - """, - label: "Update ${hub_image} on Docker Hub", - ) -} - -def deploy_docs() { - // Note: This code must stay in the Jenkinsfile to ensure that it runs - // from a trusted context only - sh( - script: ''' - set -eux - rm -rf tvm-site - git clone -b $DOCS_DEPLOY_BRANCH --depth=1 https://github.com/apache/tvm-site - cd tvm-site - git status - git checkout -B $DOCS_DEPLOY_BRANCH - - git ls-tree HEAD docs/ --name-only | grep -vP '^docs/v\\d' | xargs rm -rf - mkdir -p docs - tar xf ../docs.tgz -C docs - COMMIT=$(cat docs/commit_hash) - git add . 
- git config user.name tvm-bot - git config user.email 95660001+tvm-bot@users.noreply.github.com - git commit -m"deploying docs (apache/tvm@$COMMIT)" - git status - ''', - label: 'Unpack docs and update tvm-site' - ) - - withCredentials([string( - credentialsId: 'docs-push-token', - variable: 'GITHUB_TOKEN', - )]) { - sh( - script: ''' - cd tvm-site - git remote add deploy https://$GITHUB_TOKEN:x-oauth-basic@github.com/apache/tvm-site.git - git push deploy $DOCS_DEPLOY_BRANCH || true - ''', - label: 'Upload docs to apache/tvm-site' - ) - } -} - - -def deploy() { - stage('Deploy') { - if (env.BRANCH_NAME == 'main') { - parallel( - 'Deploy Docs': { - if (env.DOCS_DEPLOY_ENABLED == 'yes') { - node('CPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/deploy-docs") { - timeout(time: max_time, unit: 'MINUTES') { - init_git() - sh( - script: """ - set -eux - . ci/scripts/retry.sh - retry 3 aws s3 cp --no-progress s3://${s3_prefix}/docs/docs.tgz docs.tgz - md5sum docs.tgz - """, - label: 'Download artifacts from S3', - ) - - deploy_docs() - } - } - } - } else { - Utils.markStageSkippedForConditional('Deploy Docs') - } - }, - 'Upload built Docker images': { - if (env.DEPLOY_DOCKER_IMAGES == 'yes' && rebuild_docker_images && upstream_revision != null) { - node('CPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/deploy-docker") { - timeout(time: max_time, unit: 'MINUTES') { - init_git() - try { - withCredentials([string( - credentialsId: 'dockerhub-tlcpackstaging-key', - variable: 'DOCKERHUB_KEY', - )]) { - sh( - script: 'docker login -u tlcpackstaging -p ${DOCKERHUB_KEY}', - label: 'Log in to Docker Hub', - ) - } - def date_Ymd_HMS = sh( - script: 'python3 -c \'import datetime; print(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))\'', - label: 'Determine date', - returnStdout: true, - ).trim() - def tag = "${date_Ymd_HMS}-${upstream_revision.substring(0, 8)}" - update_docker(built_ci_arm, "tlcpackstaging/ci_arm:${tag}") - update_docker(built_ci_cortexm, "tlcpackstaging/ci_cortexm:${tag}") - update_docker(built_ci_cpu, "tlcpackstaging/ci_cpu:${tag}") - update_docker(built_ci_gpu, "tlcpackstaging/ci_gpu:${tag}") - update_docker(built_ci_hexagon, "tlcpackstaging/ci_hexagon:${tag}") - update_docker(built_ci_i386, "tlcpackstaging/ci_i386:${tag}") - update_docker(built_ci_lint, "tlcpackstaging/ci_lint:${tag}") - update_docker(built_ci_minimal, "tlcpackstaging/ci_minimal:${tag}") - update_docker(built_ci_riscv, "tlcpackstaging/ci_riscv:${tag}") - update_docker(built_ci_wasm, "tlcpackstaging/ci_wasm:${tag}") - } finally { - sh( - script: 'docker logout', - label: 'Clean up login credentials' - ) - } - } - } - } - } else { - Utils.markStageSkippedForConditional('Upload built Docker images') - } - }, - 'Tag tlcpackstaging to tlcpack': { - if (env.DOCS_DEPLOY_ENABLED == 'yes') { - node('CPU') { - ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/tag-images") { - timeout(time: max_time, unit: 'MINUTES') { - init_git() - withCredentials([string( - credentialsId: 'dockerhub-tlcpack-key', - variable: 'TLCPACK_TOKEN', - )]) { - try { - sh( - script: 'echo $TLCPACK_TOKEN | docker login --username octomldriazati --password-stdin', - label: 'Log in to Docker Hub' - ) - if (ci_arm.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_arm.split(":")[1] - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - docker pull tlcpackstaging/ci_arm:${tag} - docker tag tlcpackstaging/ci_arm:${tag} tlcpack/ci-arm:${tag} - retry 5 docker push tlcpack/ci-arm:${tag} - """, - label: 'Tag tlcpackstaging/ci_arm image to tlcpack', - ) - } - if (ci_cortexm.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_cortexm.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_cortexm:${tag} - docker tag tlcpackstaging/ci_cortexm:${tag} tlcpack/ci-cortexm:${tag} - retry 5 docker push tlcpack/ci-cortexm:${tag} - """, - label: 'Tag tlcpackstaging/ci_cortexm image to tlcpack', - ) - } - if (ci_cpu.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_cpu.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_cpu:${tag} - docker tag tlcpackstaging/ci_cpu:${tag} tlcpack/ci-cpu:${tag} - retry 5 docker push tlcpack/ci-cpu:${tag} - """, - label: 'Tag tlcpackstaging/ci_cpu image to tlcpack', - ) - } - if (ci_gpu.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_gpu.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_gpu:${tag} - docker tag tlcpackstaging/ci_gpu:${tag} tlcpack/ci-gpu:${tag} - retry 5 docker push tlcpack/ci-gpu:${tag} - """, - label: 'Tag tlcpackstaging/ci_gpu image to tlcpack', - ) - } - if (ci_hexagon.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_hexagon.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_hexagon:${tag} - docker tag tlcpackstaging/ci_hexagon:${tag} tlcpack/ci-hexagon:${tag} - retry 5 docker push tlcpack/ci-hexagon:${tag} - """, - label: 'Tag tlcpackstaging/ci_hexagon image to tlcpack', - ) - } - if (ci_i386.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_i386.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_i386:${tag} - docker tag tlcpackstaging/ci_i386:${tag} tlcpack/ci-i386:${tag} - retry 5 docker push tlcpack/ci-i386:${tag} - """, - label: 'Tag tlcpackstaging/ci_i386 image to tlcpack', - ) - } - if (ci_lint.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_lint.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_lint:${tag} - docker tag tlcpackstaging/ci_lint:${tag} tlcpack/ci-lint:${tag} - retry 5 docker push tlcpack/ci-lint:${tag} - """, - label: 'Tag tlcpackstaging/ci_lint image to tlcpack', - ) - } - if (ci_minimal.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_minimal.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_minimal:${tag} - docker tag tlcpackstaging/ci_minimal:${tag} tlcpack/ci-minimal:${tag} - retry 5 docker push tlcpack/ci-minimal:${tag} - """, - label: 'Tag tlcpackstaging/ci_minimal image to tlcpack', - ) - } - if (ci_riscv.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_riscv.split(":")[1] - sh( - script: """ - set -eux - . ci/scripts/retry.sh - docker pull tlcpackstaging/ci_riscv:${tag} - docker tag tlcpackstaging/ci_riscv:${tag} tlcpack/ci-riscv:${tag} - retry 5 docker push tlcpack/ci-riscv:${tag} - """, - label: 'Tag tlcpackstaging/ci_riscv image to tlcpack', - ) - } - if (ci_wasm.contains("tlcpackstaging")) { - // Push image to tlcpack - def tag = ci_wasm.split(":")[1] - sh( - script: """ - set -eux - . 
ci/scripts/retry.sh - docker pull tlcpackstaging/ci_wasm:${tag} - docker tag tlcpackstaging/ci_wasm:${tag} tlcpack/ci-wasm:${tag} - retry 5 docker push tlcpack/ci-wasm:${tag} - """, - label: 'Tag tlcpackstaging/ci_wasm image to tlcpack', - ) - } - } finally { - sh( - script: 'docker logout', - label: 'Clean up login credentials' - ) - } - } - } - } - } - } else { - Utils.markStageSkippedForConditional('Tag tlcpackstaging to tlcpack') - } - }, - ) - } - } -} - - -cancel_previous_build() - -prepare() - -if (rebuild_docker_images) { - build_docker_images() -} - -lint() - -build() - -test() - -deploy() diff --git a/apps/relax_examples/e2e_auto_tir.py b/apps/relax_examples/e2e_auto_tir.py new file mode 100644 index 0000000000..92cda16f79 --- /dev/null +++ b/apps/relax_examples/e2e_auto_tir.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import datetime +import os +import csv +import json +import argparse +import logging +from typing import Dict +import numpy as np # type: ignore + +import tvm +from tvm import relay, relax, runtime, transform +from tvm.ir.module import IRModule +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.relax.testing import relay_translator +from tvm.target.target import Target + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--input-shape", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + default=None, + ) + args.add_argument( + "--rpc-port", + type=int, + default=None, + ) + args.add_argument( + "--rpc-key", + type=str, + default=None, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--cache-dir", + type=str, + default=None, + ) + args.add_argument( + "--rpc-timeout-sec", + type=int, + default=180, + ) + args.add_argument("--num-measurement-repeats", type=int, default=5) + args.add_argument("--num-measurements", type=int, default=10) + args.add_argument("--results-file", type=str, required=False, default=None) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.input_shape = json.loads(parsed.input_shape) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key: + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + 
tracker_key=parsed.rpc_key, + session_timeout_sec=parsed.rpc_timeout_sec, + ) + parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False) + else: + # check all rpc configs are None + assert ( + (parsed.rpc_host is None) and (parsed.rpc_port is None) and (parsed.rpc_key is None) + ), "Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use the RPC server" + parsed.rpc_config = None + parsed.workers = 1 + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def apply_opt_before_tuning( + relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target +): + with transform.PassContext(opt_level=3): + main_func = relay_mod["main"] + bind_main_func = relay.build_module.bind_params_by_name(main_func, params) + relay_mod = IRModule.from_expr(bind_main_func) + relay_mod = relay.transform.SimplifyInference()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + relay_mod = relay.transform.FoldScaleAxis()(relay_mod) + relay_mod = relay.transform.CanonicalizeOps()(relay_mod) + relay_mod = relay.transform.AlterOpLayout()(relay_mod) + relay_mod = relay.transform.FoldConstant()(relay_mod) + + relax_mod = relay_translator.from_relay(relay_mod["main"], target=target) + relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod) + relax_mod = relax.transform.FuseOps()(relax_mod) + relax_mod = relax.transform.FuseTIR()(relax_mod) + return relax_mod + + +def f_measurement( + rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray] +): + vm = relax.vm.VirtualMachine(exec=rt_mod, device=device) + vm.save_function("main", "measure_func", **input_data, include_return=False) + evaluator = vm.time_evaluator( + func_name="measure_func", + dev=device, + repeat=ARGS.num_measurement_repeats, + number=ARGS.num_measurements, + min_repeat_ms=500, + ) + return evaluator() + + +def get_runner(): + runner_config = { + "evaluator_config": ms.runner.EvaluatorConfig( + number=3, + repeat=1, + min_repeat_ms=100, + enable_cpu_cache_flush=False, + ), + "alloc_repeat": ARGS.alloc_repeat, + } + if ARGS.rpc_config: + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config + ) + else: + runner = ms.runner.LocalRunner(**runner_config) + + return runner + + +def main(): + relay_mod, params, (input_name, input_shape, input_dtype) = get_network( + ARGS.workload, + ARGS.input_shape, + cache_dir=ARGS.cache_dir, + ) + input_info = {input_name: input_shape} + input_data = {} + for input_name, input_shape in input_info.items(): + print(f" input_name: {input_name}") + print(f" input_shape: {input_shape}") + print(f" input_dtype: {input_dtype}") + + # translate the workload from Relay to Relax + relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target) + assert isinstance(relax_mod, tvm.IRModule) + + db = ms.relax_integration.tune_relax( + mod=relax_mod, + target=ARGS.target, + params=params, + num_trials_per_iter=64, + max_trials_per_task=ARGS.num_trials, + max_trials_global=ARGS.num_trials, + runner=get_runner(), + work_dir=ARGS.work_dir, + ) + executable = ms.relax_integration.compile_relax( + db, + mod=relax_mod, + target=ARGS.target, + params=params, + ) + + for input_name, input_shape in input_info.items(): + if input_dtype.startswith("float"): + input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) + else: + input_data[input_name] = np.random.randint( + low=0, high=10000, size=input_shape,
dtype=input_dtype + ) + + # for documentation purposes + start_time = datetime.datetime.now() + + if ARGS.rpc_config: + result = run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=executable.mod, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_measurement, + ) + else: + dev = tvm.device(ARGS.target.kind.name) + result = f_measurement(executable.mod, dev, input_data) + + print(result) + + if not ARGS.results_file: + return + + out_path = os.path.abspath(os.path.expanduser(ARGS.results_file)) + with open(out_path, "w") as out_file: + writer = csv.writer(out_file) + # write experiment parameters at the top as a record + writer.writerow(["start", str(start_time)]) + writer.writerow(["workload", ARGS.workload]) + writer.writerow(["input_shape", ARGS.input_shape]) + writer.writerow(["target", ARGS.target]) + writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats]) + for res in result.results: + writer.writerow([str(res)]) + + +if __name__ == "__main__": + main() diff --git a/apps/relax_examples/mlp.py b/apps/relax_examples/mlp.py new file mode 100644 index 0000000000..fa69524a80 --- /dev/null +++ b/apps/relax_examples/mlp.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example code on creating, compiling, and running an MLP model in relax + + +import tvm +import tvm.contrib.cblas # ensure the cblas contrib module is loaded for emit_te below +from tvm.relay import Call +from tvm import relax, tir, topi +import numpy as np + + +def build_mlp(data, weight): + bb = relax.BlockBuilder() + + with bb.function("mlp", [data, weight]): + gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False) + gv1 = bb.emit_te(topi.nn.relu, gv0) + bb.emit_func_output(gv1) + + mod = bb.get() + return mod + + +if __name__ == "__main__": + # symbolic dimensions + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + # create data and weight variables + data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32")) + weight = relax.Var("weight", [m, n], relax.DynTensorType(2, "float32")) + + # construct an MLP model + mod = build_mlp(data, weight) + + # build and create vm executor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # run the MLP model on relax vm + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = vm["mlp"](data, weight) + print(res) diff --git a/apps/relax_examples/nn_module.py b/apps/relax_examples/nn_module.py new file mode 100644 index 0000000000..45405ae398 --- /dev/null +++ b/apps/relax_examples/nn_module.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Example code on creating, compiling, and running a neural network with pytorch-like API + + +import tvm +from tvm.relay import Call +from tvm import relax, tir +from tvm.relax.testing import nn +from tvm.script import relax as R +import numpy as np + + +if __name__ == "__main__": + builder = relax.BlockBuilder() + + # a symbolic variable to represent minibatch size + n = tir.Var("n", "int64") + input_size = 784 + hidden_sizes = [128, 32] + output_size = 10 + + # build a three linear-layer neural network for a classification task + with builder.function("main"): + model = nn.Sequential( + nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LogSoftmax(), + ) + data = nn.Placeholder((n, input_size), name="data") + output = model(data) + params = [data] + model.parameters() + builder.emit_func_output(output, params=params) + + # get and print the IRmodule being built + mod = builder.get() + print(R.parser.astext(mod)) + + # build the IRModule and create relax vm + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init parameters + params = nn.init_params(mod) + + # run the model on relax vm + # the input data has a minibatch size of 3 + data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32)) + res = vm["main"](data, *params) + print(res) diff --git a/apps/relax_examples/resnet.py b/apps/relax_examples/resnet.py new file mode 100644 index 0000000000..3afb00c3eb --- /dev/null +++ b/apps/relax_examples/resnet.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Example ResNet workload by translating the Relay program to Relax""" + +import tvm +import tvm.testing +from tvm.relay import testing +from tvm import relax, relay +from tvm.relax.testing import relay_translator, nn +from tvm.runtime import vm as vm_rt +from tvm.script import relax as R +import numpy as np + +if __name__ == "__main__": + relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + + # translate the ResNet model from Relay to Relax + target = tvm.target.Target("llvm", host="llvm") + relax_mod = relay_translator.from_relay(relay_mod["main"], target) + + # print the ResNet IRmodule got translated + print(R.parser.astext(relax_mod)) + + # build the IRModule and create relax vm + ex = relax.vm.build(relax_mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + # init weights and run the model on relax vm + shape = (1, 3, 224, 224) + data = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + params = nn.init_params(relax_mod) + res = vm["main"](data, *params) + + # check correctness by comparing with relay result + exe = relay.vm.compile(relay_mod, target) + relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu()) + inputs = [data] + params + expected_output = relay_vm.run(*inputs) + tvm.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4) diff --git a/cmake/modules/contrib/TensorRT.cmake b/cmake/modules/contrib/TensorRT.cmake index 696108b501..a749b6e80f 100644 --- a/cmake/modules/contrib/TensorRT.cmake +++ b/cmake/modules/contrib/TensorRT.cmake @@ -23,7 +23,7 @@ include (FindPackageHandleStandardArgs) if(USE_TENSORRT_CODEGEN) message(STATUS "Build with TensorRT codegen") - tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc) + tvm_file_glob(GLOB COMPILER_TENSORRT_SRCS src/relay/backend/contrib/tensorrt/*.cc src/relax/backend/contrib/tensorrt/*.cc) set_source_files_properties(${COMPILER_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") tvm_file_glob(GLOB RUNTIME_TENSORRT_SRCS src/runtime/contrib/tensorrt/tensorrt_runtime.cc) set_source_files_properties(${RUNTIME_TENSORRT_SRCS} PROPERTIES COMPILE_FLAGS "-Wno-deprecated-declarations") diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b9afb4be2d..08f21bab8a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -133,6 +133,7 @@ class PrimExpr : public BaseExpr { TVM_DLL static PrimExpr FromObject_(ObjectRef ref); }; +class RelayExpr; /*! * \brief add operator * @@ -365,10 +366,28 @@ class RelayExprNode : public BaseExprNode { * This value is discarded during serialization. */ mutable Type checked_type_ = Type(nullptr); + + /*! + * \brief Stores the result of static shape analysis. It must be a RelayExpr + * and ObjectRef is used here to avoid cyclic typing. + * + * \note The value will be optional if a static shape can not be inferred. + * use .shape() instead to acesss an always defined shape expression. + */ + mutable Optional shape_ = Optional(); + /*! * \return The checked_type */ inline const Type& checked_type() const; + + /*! + * \return An expression which corresponds to the shape of the expression. + * + * Only valid when the expression's type is a Tensor. + */ + RelayExpr shape() const; + /*! * \brief Check if the inferred(checked) type of the Expr * is backed by a TTypeNode and return it. 
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 1493544e73..c63f8efe9e 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -65,6 +65,68 @@ enum class CallingConv : int { kDeviceKernelLaunch = 2, }; +/*! + * \brief Supported linkage types. + */ +enum class LinkageType : int { + /*! + * \brief Internal linkage. + */ + kInternal = 0, + /*! + * \brief External linkage. + - Function with external linkage should have a global symbol attached to it. + */ + kExternal = 1 +}; + +/*! + * \brief Generic attribute names that can be attached to any function. + * + * \sa tvm::tir::attr, tvm::relay::attr + */ +namespace attr { +/*! + * \brief Indicates the special calling convention. + * + * Type: Integer + * + * \sa tvm::CallingConv + */ +constexpr const char* kCallingConv = "calling_conv"; + +/*! + * \brief Compilation target of the function. + * + * Type: Target + * + * \sa tvm::Target + */ +constexpr const char* kTarget = "target"; + +/*! + * \brief Global linker symbol of the function in generated code. + * + * This option forces the code generator to name the + * function with the given. + * + * For example, we could set a global_symbol of a function + * early to make sure that we can always refer to it by + * the symbol name in the generated DLL. + * + * We should not set the attribute for local functions, + * so that the compiler can freely rename them. + * + * A unique global symbol will be automatically assigned + * to each function in the module before the target code + * generation phase. + * + * Type: String + */ +constexpr const char* kGlobalSymbol = "global_symbol"; + +} // namespace attr + /*! * \brief Base node of all functions. * @@ -131,6 +193,32 @@ class BaseFuncNode : public RelayExprNode { */ bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } + /*! + * \brief Get the type of the linkage. + * + * Currently, we only consider external/internal linkage. + * This can be extended in the future when necessary. + * + * \return Linkage type. + * + * \code + * + * void Example(const BaseFunc& f) { + * if (f->GetLinkageType() == tvm::LinkageType::kExternal) { + * // Do not remove a function with external linkage + * } + * } + * + * \endcode + */ + + LinkageType GetLinkageType() const { + if (GetAttr(attr::kGlobalSymbol)) + return LinkageType::kExternal; + else + return LinkageType::kInternal; + } + static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); @@ -145,51 +233,5 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Generic attribute names that can be attached to any function. - * - * \sa tvm::tir::attr, tvm::relay::attr - */ -namespace attr { -/*! - * \brief Indicates the special calling convention. - * - * Type: Integer - * - * \sa tvm::CallingConv - */ -constexpr const char* kCallingConv = "calling_conv"; - -/*! - * \brief Compilation target of the function. - * - * Type: Target - * - * \sa tvm::Target - */ -constexpr const char* kTarget = "target"; - -/*! - * \brief Global linker symbol of the function in generated code. - * - * This option forces the code generator to name the - * function with the given. - * - * For example, we could set a global_symbol of a function - * early to make sure that we can always refer to it by - * the symbol name in the generated DLL. 
- * - * We should not set the attribute for local functions, - * so that the compiler can freely rename them. - * - * A unique global symbol will be automatically assigned - * to each function in the module before the target code - * generation phase. - * - * Type: String - */ -constexpr const char* kGlobalSymbol = "global_symbol"; - -} // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 7313b4f783..7ae5adfe40 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -95,6 +95,12 @@ class IRModuleNode : public Object { return GetAttr(attr_key, Optional(default_value)); } + /*! + * \brief Get the metadata attributes. + * \returns The additional meta-data attributes + */ + DictAttrs GetAttrs() const { return attrs; } + /*! * \brief Check whether the module has an non-zero integer attr. * @@ -357,7 +363,7 @@ class IRModule : public ObjectRef { * \param type_definitions Type definitions in the module. * \param import_set Set of imported files in the module. * \param map The module source map. - * \param attrs The module attributes. + * \param attrs The module meta-data attributes. */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 6e6b8bee5f..5fda19eab5 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -146,6 +146,9 @@ class OpNode : public RelayExprNode { // Internal function to compute if it is primitive op bool IsPrimitiveOp_() const { const auto& fn_ty = this->op_type; + if (!fn_ty.get()) { + return false; + } ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered"; if (fn_ty->type_constraints.size() != 1) return false; const TypeRelationNode* rel = fn_ty->type_constraints[0].as(); diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index febcca5c01..d22568b3bc 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -32,18 +32,18 @@ * - Reducing the effort required to implement new passes for compiler * developers, etc. * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * Similar to LLVM's pass manager, we designed the Relay/Relax pass manager to work * different granularity, i.e. module level, function level, and even sequential * passe that contains a host of passes. * * However, we also extend the functionality of the traditional pass manager * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay/Relax pass * manager performs the IRModule -> IRModule transformation. All * different types of passes, including the sequential-level pass object, are * essentially pass objects. This design, therefore, effectively provides users * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. For example, with + * means to ease the development and testing of Relay/Relax passes. For example, with * the pass manager, external users will be able to have custom passes correctly * scheduled without having to modify a single handcrafted pass order. * @@ -79,7 +79,6 @@ class PassContextNode : public Object { public: /*! \brief The default optimization level. */ int opt_level{2}; - /*! \brief The list of required passes. */ Array required_pass; /*! \brief The list of disabled passes. 
*/ @@ -88,12 +87,19 @@ class PassContextNode : public Object { mutable Optional<DiagnosticContext> diag_ctx; /*! \brief Pass specific configurations. */ Map<String, ObjectRef> config; - /*! \brief A list of pass instrument implementations. */ Array<instrument::PassInstrument> instruments; - + // TODO(@sunggg): Fix dependency issue in the header file and correct the types + // e.g., relax::trace, relax::database in tvm/relax/tuning_api.h + /*! \brief Trace stack for relax pass infra. */ + mutable Array<ObjectRef> trace_stack; + /*! \brief List of passes to be traced. If not defined, make every pass traceable. */ + Optional<Array<String>> make_traceable; + /*! \brief Number of evaluations conducted in the pass pipeline. */ + mutable int num_evals{0}; + /*! \brief Database for tuning API. */ + Optional<ObjectRef> tuning_api_database; PassContextNode() = default; - /*! * \brief Get a config value from the pass context. * @@ -131,7 +137,27 @@ v->Visit("instruments", &instruments); v->Visit("config", &config); v->Visit("diag_ctx", &diag_ctx); + v->Visit("trace_stack", &trace_stack); + v->Visit("make_traceable", &make_traceable); + v->Visit("num_evals", &num_evals); + v->Visit("tuning_api_database", &tuning_api_database); + } + + Array<ObjectRef> GetTraceStack() { return trace_stack; } + void PushTrace(ObjectRef new_trace) { trace_stack.push_back(new_trace); } + void PopTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; + trace_stack.pop_back(); + } + int GetTraceStackSize() { return trace_stack.size(); } + ObjectRef GetCurrentTrace() { + ICHECK(GetTraceStackSize()) << "Trace stack is currently empty. Please double check."; + return trace_stack.back(); } + void SetNumEvals(int _num_evals) { num_evals = _num_evals; } + void IncNumEvals(int _num_evals) { num_evals += _num_evals; } + + Optional<ObjectRef> GetTuningAPIDatabase() { return tuning_api_database; } static constexpr const char* _type_key = "transform.PassContext"; static constexpr bool _type_has_method_sequal_reduce = false; @@ -280,6 +306,7 @@ class PassContext : public ObjectRef { * \brief Meta data that will be used to help optimization and analysis. * \sa PassInfo */ + class PassInfoNode : public Object { public: /*! \brief The minimal optimization level that this pass will be enabled. */ @@ -288,6 +315,9 @@ class PassInfoNode : public Object { /*! \brief The name of an optimization/analysis pass. */ String name; + /*! \brief Boolean that tells whether this pass will be traced or not. */ + bool traceable; + /*! \brief The passes that are required to perform the current pass. */ Array<runtime::String> required; @@ -297,6 +327,7 @@ v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); + v->Visit("traceable", &traceable); } static constexpr const char* _type_key = "transform.PassInfo"; @@ -315,8 +346,9 @@ * \param opt_level The optimization level * \param name Name of the pass. * \param required The passes that are required to perform the current pass. + * \param traceable Boolean that tells whether the pass is traceable. */ - TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required); + TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required, bool traceable); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -324,7 +356,7 @@ /*! * \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. + * at different granularity of Relay/Relax nodes. */ class PassNode : public Object { public: @@ -397,7 +429,7 @@ class Pass : public ObjectRef { }; /*! - * \brief The SequentialNode contains a set of passes that transform Relay + * \brief The SequentialNode contains a set of passes that transform Relay/Relax * programs from one AST to another semantically equivalent one. * * One example of this level of pass is that the pass manager needs to correctly @@ -490,9 +522,9 @@ * * \return The created module pass. */ -TVM_DLL Pass -CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, - int opt_level, String name, Array<runtime::String> required); +TVM_DLL Pass CreateModulePass( + const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func, int opt_level, + String name, Array<runtime::String> required, bool traceable = false); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 579061e02e..12bb127d01 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -377,7 +377,7 @@ class TupleType : public Type { inline Type VoidType() { return TupleType::Empty(); } /*! - * \brief Check whether the tyep represents void. + * \brief Check whether the type represents void. * \return The check result. */ inline bool IsVoidType(const Type& type) { diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 11bf7d4740..54dc3aa57c 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -25,6 +25,7 @@ #define TVM_IR_TYPE_FUNCTOR_H_ #include <tvm/node/functor.h> +#include <tvm/relax/type.h> #include <string> #include <utility> @@ -89,6 +90,10 @@ class TypeFunctor { virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const relax::ShapeTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const relax::ObjectTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const relax::DynTensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const relax::DimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning @@ -112,6 +117,10 @@ TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode); TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode); + TVM_TYPE_FUNCTOR_DISPATCH(relax::ShapeTypeNode); + TVM_TYPE_FUNCTOR_DISPATCH(relax::ObjectTypeNode); + TVM_TYPE_FUNCTOR_DISPATCH(relax::DynTensorTypeNode); + TVM_TYPE_FUNCTOR_DISPATCH(relax::DimTypeNode); return vtable; } }; diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 4092fdae36..37aaf09aa8 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -404,6 +404,9 @@ class PyDatabaseNode : public DatabaseNode { */ class Database : public runtime::ObjectRef { public: + /*! Default constructor */ + Database() = default; + /*! An in-memory database. */ TVM_DLL static Database MemoryDatabase(); /*!
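The trace stack and evaluation counter added to PassContextNode above pair with the new `traceable` flag threaded through PassInfo and CreateModulePass. A hedged sketch of a pass that opts in, assuming only the declarations in these hunks plus TVM's usual mutable `operator->` on PassContext; the pass name and the choice of the module itself as the trace payload are hypothetical, not from the patch:

#include <tvm/ir/transform.h>

// Hypothetical no-op module pass exercising the new tuning hooks.
tvm::transform::Pass MakeNoOpTraceablePass() {
  auto pass_func = [](tvm::IRModule mod, tvm::transform::PassContext ctx) -> tvm::IRModule {
    ctx->PushTrace(mod);  // any ObjectRef can serve as a trace entry
    ctx->IncNumEvals(1);  // bookkeeping consumed by the tuning API
    ctx->PopTrace();
    return mod;
  };
  // The trailing argument is the new `traceable` flag stored in PassInfo.
  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "NoOpTraceable",
                                          /*required=*/{}, /*traceable=*/true);
}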
diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 8b8a403326..638d6865ce 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tvm/node/structural_equal.h + * \file tvm/node/structural_hash.h * \brief Structural hash class. */ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ @@ -174,7 +174,7 @@ class SHashReducer { /*! * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. - * \note This function indicate key could contain var defintions. + * \note This function indicates the key could contain variable definitions. */ void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } /*! diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h new file mode 100644 index 0000000000..f922a9f5c2 --- /dev/null +++ b/include/tvm/relax/analysis.h @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis.h + * \brief The set of Relax specific analysis passes. + */ +#ifndef TVM_RELAX_ANALYSIS_H_ +#define TVM_RELAX_ANALYSIS_H_ + +#include <tvm/ir/diagnostic.h> +#include <tvm/ir/module.h> +#include <tvm/relax/expr.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/tir/function.h> + +#include <utility> + +namespace tvm { +namespace relax { + +/*! + * \brief Check if the IRModule is well formed. + * + * \param m the IRModule to check. + * \param diag_ctx the diagnostic context. + * \return true if the IRModule is well formed, false if not. + */ +TVM_DLL bool WellFormed(const IRModule& m, + Optional<DiagnosticContext> diag_ctx = Optional<DiagnosticContext>()); + +/*! + * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. + * + * \param func The PrimFunc to be analyzed. + * \return The Op Pattern Kind. + * + * \note This analysis applies on TIR function but is primarily used by relax passes. + * As a result we place it under the relax namespace. + */ +TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); + +/*! + * \brief Gather all shape variables from expression expr. + * + * This analysis is intended to be called on shape expressions (those set as the shape_ of another + * expression). + * + * \param expr the expression. Meant to be a shape expression. + * + * \return List of shape variables (tir::Var) + */ +TVM_DLL tvm::Array<tir::Var> ShapeVars(const Expr& expr); + +/*! + * \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr); + +/*! + * \brief Get free variables from expression expr.
+ * + * Free variables are variables that are not bound by a + * VarBinding or a function parameter in the context. + * + * \param expr the expression. + * + * \return List of free vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr); + +/*! + * \brief Get all variables from expression expr. + * + * AllVars is a superset of BoundVars and FreeVars. + * The union of BoundVars and FreeVars is AllVars. + * + * \param expr the expression. + * + * \return List of all vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array<Var> AllVars(const Expr& expr); + +/*! + * \brief Get all global variables used in calls in expression expr. + * + * \param expr the expression. + * + * \return List of all global variables called in expr. + */ +TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr); + +/*! + * \brief Get all global variables from expression expr. + * + * \param expr the expression. + * + * \return List of all global variables, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param m The IRModule to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const IRModule& m); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param expr The expression to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const Expr& expr); + +/*! + * \brief Analyze var -> value mapping from VarBindings. + * + * \param dfb The dataflow block to check. + * \return Var -> Value (Expr) + */ +TVM_DLL Map<Var, Expr> AnalyzeVar2Value(const DataflowBlock& dfb); + +/*! + * \brief Return a mapping from variable name to its Bindings. + * + * \param fn The function to be analyzed. + * \return A mapping from variable name to its Bindings. + */ +TVM_DLL Map<String, Array<Binding>> NameToBinding(const Function& fn); + +/*! + * \brief Get the use-def chain of variables inside a dataflow block. + * + * \param dfb The dataflow block to be analyzed. + * \return A map mapping variable definitions to a set of uses. + */ +TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb); + +/*! + * \brief Get the use-def chain of variables inside a function. + * + * \param fn The function to be analyzed. + * \return A map from variable definitions to a set of uses and variables needed by return value. + */ +std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn); + +/*! + * \brief Remove unused statements inside DataflowBlocks. + * + * \param fn The function to remove unused statements. + * \return The function that contains no unused statements in DataflowBlock. + */ +TVM_DLL Function RemoveAllUnused(const Function fn); + +/*! + * \brief Given the argument vars and body, derives a return shape for a function with those args + * and that body. If the body's shape contains free shape vars (those not used in the args), the + * return shape is relaxed to RuntimeDepShape; otherwise, the body's shape is used.
+ * + * \param args The argument variables, ideally with the shape_ field filled in + * \param body The function body, ideally with the shape_ field filled in + * \return An expression that can serve as the return shape for the function + */ +TVM_DLL Expr DeriveFuncRetShape(Array<Var> args, Expr body); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ANALYSIS_H_ diff --git a/include/tvm/relax/attrs/memory.h b/include/tvm/relax/attrs/memory.h new file mode 100644 index 0000000000..66e3fbf1b1 --- /dev/null +++ b/include/tvm/relax/attrs/memory.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/memory.h + * \brief Attributes for memory operators. + */ +#ifndef TVM_RELAX_ATTRS_MEMORY_H_ +#define TVM_RELAX_ATTRS_MEMORY_H_ + +#include <tvm/ir/attrs.h> + +#include "tvm/ir/memory_pools.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Attributes for allocating tensor. + */ +struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> { + DataType dtype; + int64_t runtime_device_index; + Array<PoolInfo> candidate_memory_pools; + + TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") { + TVM_ATTR_FIELD(dtype).describe("The datatype of the tensor to be allocated."); + TVM_ATTR_FIELD(runtime_device_index) + .describe( + "The device index indicating on which device the tensor is to be allocated at runtime. " + "Index -1 is reserved for the host device.") + .set_default(-1); + TVM_ATTR_FIELD(candidate_memory_pools) + .describe("The candidate memory pools when USMP is used. Empty if USMP is not used.") + .set_default(Array<PoolInfo>()); + } +}; + +/*! + * \brief Attributes for allocating storage on Relax VM. + */ +struct VMAllocStorageAttrs : public tvm::AttrsNode<VMAllocStorageAttrs> { + DataType dtype; + int64_t runtime_device_index; + + TVM_DECLARE_ATTRS(VMAllocStorageAttrs, "relax.attrs.VMAllocStorageAttrs") { + TVM_ATTR_FIELD(dtype) + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(runtime_device_index) + .describe( + "The device index indicating on which device the tensor is to be allocated at runtime. " + "Index -1 is reserved for the host device.") + .set_default(-1); + } +}; + +/*! + * \brief Attributes for allocating tensor on Relax VM.
+ */ +struct VMAllocTensorAttrs : public tvm::AttrsNode<VMAllocTensorAttrs> { + int offset; + DataType dtype; + + TVM_DECLARE_ATTRS(VMAllocTensorAttrs, "relax.attrs.VMAllocTensorAttrs") { + TVM_ATTR_FIELD(offset).describe("Storage offset to allocate the tensor.").set_default(0); + TVM_ATTR_FIELD(dtype) + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_ATTRS_MEMORY_H_ diff --git a/include/tvm/relax/attrs/shape.h b/include/tvm/relax/attrs/shape.h new file mode 100644 index 0000000000..9c4aaad24b --- /dev/null +++ b/include/tvm/relax/attrs/shape.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/shape.h + * \brief Attributes for shape operators. + */ +#ifndef TVM_RELAX_ATTRS_SHAPE_H_ +#define TVM_RELAX_ATTRS_SHAPE_H_ + +#include <tvm/ir/attrs.h> + +namespace tvm { +namespace relax { +/*! + * \brief Attributes for decoding/making shape to/from VM heap. + */ +struct ShapeHeapAttrs : public tvm::AttrsNode<ShapeHeapAttrs> { + Array<Integer> indices; + + TVM_DECLARE_ATTRS(ShapeHeapAttrs, "relax.attrs.ShapeHeapAttrs") { + TVM_ATTR_FIELD(indices).describe("The indices of the heap to store/load the shape to/from."); + } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_ATTRS_SHAPE_H_ diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h new file mode 100644 index 0000000000..596905ae9d --- /dev/null +++ b/include/tvm/relax/backend.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend.h + * \brief Relax backend specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_H_ +#define TVM_RELAX_BACKEND_H_ + +#include <tvm/relax/transform.h> + +namespace tvm { +namespace relax { +namespace transform { + +/*! + * \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics. + * + * \return The Pass. + */ +TVM_DLL Pass VMMemoryLower(); + +/*!
+ * \brief Lower the shape expression in relax to VM shape heap and TIR functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMShapeLower(); + +} // namespace transform +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_H_ diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h new file mode 100644 index 0000000000..7cee0bf1d5 --- /dev/null +++ b/include/tvm/relax/binding_rewrite.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/binding_rewrite.h + * \brief An IR rewriter to easily add/remove/replace bindings (statements). + */ + +#ifndef TVM_RELAX_BINDING_REWRITE_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Statement rewriter for relax.DataflowBlock. */ +class DataflowBlockRewriteNode : public Object { + public: + /*! \brief Replace all uses of old_var with new_var. */ + void ReplaceAllUses(Var old_var, Var new_var); + /*! \brief Insert a Binding statement. */ + void Add(Binding binding); + /*! \brief Insert an expression as VarBinding with variable name. */ + void Add(String var_name, Expr expr, bool is_dfvar = false) { + auto var = is_dfvar ? DataflowVar(var_name, expr->shape(), expr->checked_type()) + : Var(var_name, expr->shape(), expr->checked_type()); + Add(VarBinding(std::move(var), std::move(expr))); + } + /*! \brief Insert an expression as VarBinding with automatic variable name. */ + void Add(Expr expr, bool is_dfvar = false) { + Add(name_table_.GetUniqueName("tmp"), expr, is_dfvar); + } + /*! \brief Remove the definition statement of an unused variable. */ + void RemoveUnused(Var unused, bool allow_undef = false); + /*! \brief Remove the definition statements of all unused variables. */ + void RemoveAllUnused(); + + /*! \brief The rewritten dataflow block. */ + DataflowBlock MutatedDataflowBlock() { return dfb_.value(); } + /*! \brief The rewritten function. */ + Function MutatedFunc() { return root_fn_.value(); } + /*! \brief The rewritten IRModule. */ + IRModule MutateIRModule(IRModule irmod); + + /*! \brief Visit attributes. */ + void VisitAttrs(AttrVisitor* v) { + v->Visit("dfb", &dfb_); + v->Visit("root_fn", &root_fn_); + } + + static constexpr const char* _type_key = "relax.DataflowBlockRewrite"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockRewriteNode, Object); + + protected: + friend class DataflowBlockRewrite; + + Optional dfb_; //!< The rewritten dataflow block. + Optional root_fn_; //!< The rewritten function. + const FunctionNode* original_fn_ptr_; //!< Pointer to the original function. + Map> to_users_; //!< Map from variable to its users. + Array fn_outputs_; //!< Variables required by function outputs. 
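+
+  // An illustrative usage sketch (assuming `dfb`, `fn`, and `mod` come from an
+  // existing module, and `some_expr` is a normalized relax::Expr):
+  //   DataflowBlockRewrite rwt(dfb, fn);
+  //   rwt->Add("tmp0", some_expr, /*is_dfvar=*/true);  // bind a new DataflowVar
+  //   rwt->RemoveAllUnused();                          // drop dead bindings
+  //   IRModule new_mod = rwt->MutateIRModule(mod);     // rebuild the module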
+ + private: + NameTable name_table_; //!< Name table for tracking and generating unique names. +}; + +/*! + * \brief A statement rewriter for relax.DataflowBlock. + * \sa DataflowBlockRewriteNode + */ +class DataflowBlockRewrite : public ObjectRef { + public: + TVM_DLL explicit DataflowBlockRewrite(DataflowBlock dfb, Function root_fn); + + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + DataflowBlockRewriteNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockRewrite, ObjectRef, DataflowBlockRewriteNode); +}; + +} // namespace relax +} // namespace tvm + +#define TVM_RELAX_BINDING_REWRITE_H_ +#endif // TVM_RELAX_BINDING_REWRITE_H_ diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h new file mode 100644 index 0000000000..409399713e --- /dev/null +++ b/include/tvm/relax/block_builder.h @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/block_builder.h + * \brief The utility for constructing Relax binding blocks. + */ +#ifndef TVM_RELAX_BLOCK_BUILDER_H_ +#define TVM_RELAX_BLOCK_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class BlockBuilder; + +/*! + * \brief A builder that provides APIs to build Relax binding blocks. + */ +class BlockBuilderNode : public Object { + public: + BlockBuilderNode(); + + ~BlockBuilderNode(); + + /*! \brief Begin to build a DataflowBlock. */ + void BeginDataflowBlock(); + + /*! \brief Begin to build a BindingBlock. */ + void BeginBindingBlock(); + + /*! + * \brief End building a BindingBlock. + * \return The BindingBlock being built. + */ + BindingBlock EndBlock(); + + /*! + * \brief Check if the block being built is DataflowBlock or not. + * \return A boolean that indicates if the block being built is DataflowBlock or not. + */ + inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; } + + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param name_hint Name hint for the bound variable. + * \note This Emit function normalizes the \p expr, and performs shape and type deductions by + * calling Normalize. + * \return The new variable that \p expr is bound to. + */ + virtual Var Emit(const Expr& expr, std::string name_hint = ""); + + /*! + * \brief Emits a variable binding, and returns the bound Var. + * \param binding The variable binding. + * \return The bound variable. + */ + virtual Var Emit(const VarBinding& binding); + + /*! + * \brief Emit a MatchShape. + * \param value The value of the MatchShape to be emitted. 
+ * \param pattern The pattern of the MatchShape to be emitted. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to the MatchShape. + */ + Var EmitMatchShape(const Expr& value, const Array& pattern, std::string name_hint = ""); + + /*! + * \brief Emit a MatchShape binding. + * \param binding The MatchShape binding to be emitted. + * \return The variable bound to the MatchShape. + */ + Var EmitMatchShape(const MatchShape& binding); + + /*! + * \brief Generate an output for the current dataflow block. + * \param output The output variable of the block. + * \param name_hint Name hint for the bound variable. + * \return The variable bound to \p output. + */ + Var EmitOutput(const Expr& output, std::string name_hint = ""); + + /*! + * \brief Generate an output for the current dataflow block. + * \param binding The output binding to output. + * \return The variable bound to \p output. + */ + Var EmitOutput(const VarBinding& binding); + + /*! + * \brief Lookup a var in the binding table \p binding_table_. + * \param var The input var. + * \return The Expr bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + Optional LookupBinding(const Var& var); + + /*! + * \brief Check if two shape expressions can be proven equal at compile time. + * \param lhs The input lhs shape. + * \param rhs The input rhs shape. + * \return Whether we can prove lhs shape is the same as the rhs shape. + */ + bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs); + + /*! + * \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes. + * \param expr The input expression. + * \return The normalized expression. + */ + Expr Normalize(const Expr& expr); + + /*! + * \brief Get the name table for generating unique names. + * + * \return The name table. + */ + NameTable* name_table(); + + /*! + * \brief Add a Relax function or a TIR PrimFunc to \p context_mod_. + * \param func The function to be added. + * \param func_name_hint The name hint of the function to be added. + * \note If the function to be added already exists in \p context_mod_, return its + * GlobalVar directly. + * \return The global var bound to the added function. + */ + GlobalVar AddFunction(const BaseFunc& func, const String& func_name_hint); + + /*! + * \brief Update a Relax function or a TIR PrimFunc in \p context_mod_. + * \param gv The global var referring the function to be updated. + * \param function The updated function. + */ + void UpdateFunction(const GlobalVar& gv, BaseFunc function); + + /*! + * \brief Get the context IRModule being built. + * \return The IRModule being built by BlockBuilder. + */ + IRModule GetContextIRModule() const; + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.BlockBuilder"; + TVM_DECLARE_BASE_OBJECT_INFO(BlockBuilderNode, Object); + + private: + /*! + * \brief Emits an Expr, and returns the variable it is bound to. + * \param expr The Expr to be emitted. + * \param is_dataflow Is the bound variable a DataflowVar or not(i.e. Var). + * \param name_hint Name hint for the bound variable. + * \note This Emit function normalizes the \p expr, and performs shape and type deductions by + * calling Normalize. + * \return The new variable that \p expr is bound to. + */ + Var Emit(const Expr& expr, bool is_dataflow, std::string name_hint); + + /*! \brief The IRModule being built by the BlockBuilder. 
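+   *
+   * A minimal emit sequence, sketched for illustration (`some_call` is an
+   * assumed, already-constructed relax::Expr):
+   *
+   * \code{.cc}
+   * BlockBuilder bb = BlockBuilder::Create(NullOpt);
+   * bb->BeginDataflowBlock();
+   * Var lv0 = bb->Emit(some_call);        // normalize and bind to a new var
+   * Var gv0 = bb->EmitOutput(lv0);        // expose lv0 as the block output
+   * BindingBlock block = bb->EndBlock();
+   * \endcode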
*/ + IRModule context_mod_; + + /*! + * \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs + * in \p _context_mod to their GlobalVar to avoid generating duplicated functions. + */ + std::unordered_map func_map_; + + protected: + /*! + * \brief A representation of a block frame. + * + * A block frame is a record containing the bindings needed + * to build a binding block, and a boolean to indicate if the + * block being built is a DataflowBlock or not. + */ + struct BlockFrame { + Array bindings; + bool is_dataflow; + }; + + /*! + * \brief Utility class for performing IR normalization (conversion to ANF, eager forward shape + * and type inference). + */ + class ExprNormalizer; + + friend class BlockBuilder; + + /*! + * \brief Get the current block frame. + * \return The current block frame. + */ + BlockFrame* CurrentFrame(); + + /*! \brief A stack to store block frames. */ + std::stack block_stack_; + + /*! \brief A diagnostic context for reporting errors. */ + DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); + + /*! \brief A binding table that maps var to value. */ + std::unordered_map binding_table_; + + /*! \brief A name table to get unique names for IR construction. */ + std::unique_ptr name_table_; + + /*! \brief The internal normalizer used for ANF conversion. */ + std::unique_ptr normalizer_; +}; + +class BlockBuilder : public ObjectRef { + public: + /*! + * \brief Create a BlockBuilder. + * \param mod Optional before-transformation IRModule for rewriting. + * \return The created BlockBuilder. + */ + TVM_DLL static BlockBuilder Create(Optional mod); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BLOCK_BUILDER_H_ diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h new file mode 100644 index 0000000000..e394e9ff53 --- /dev/null +++ b/include/tvm/relax/dataflow_matcher.h @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_matcher.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAX_DATAFLOW_MATCHER_H_ +#define TVM_RELAX_DATAFLOW_MATCHER_H_ + +#include +#include + +#include + +namespace tvm { +namespace relax { + +/** + * \brief Determine if a pattern matches an expression. + * \note The behavior of MatchExpr is to match a relax.Expr (`expr`) syntactically through + * one given pattern (`pattern`). 
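+ *
+ * Illustrative example (the pattern and the matched expression `expr` are
+ * assumptions):
+ *
+ * \code{.cc}
+ * DFPattern pat = IsOp("relax.add")(Wildcard(), Wildcard());
+ * bool matched = MatchExpr(pat, expr);  // expr: some relax::Expr to test
+ * \endcode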
+ * + * \param pattern The pattern to match + * \param expr The expression to match + * \param var2val The mapping from relax.Var to relax.Expr + * \return true if matched + * \return false if unmatched + */ +bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val = NullOpt); + +/** + * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. + * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the + * starting point of the matching so that we can distinguish multiple matches. + * + * \param ctx The graph-wise patterns. + * \param dfb The function to match. + * \param start_hint The starting point expression to match to distinguish multiple matches. + * \param must_include_hint If start_hint is given, the return pattern must include start_hint. + * \return tvm::runtime::Map + */ +TVM_DLL tvm::runtime::Map MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false); + +/** + * \brief Match a graph-wise pattern with the current context (PatternContext::Current()). + */ +inline tvm::runtime::Map MatchGraphDefault(const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false) { + return MatchGraph(PatternContext::Current(), dfb, start_hint, must_include_hint); +} + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_DATAFLOW_MATCHER_H_ diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h new file mode 100644 index 0000000000..6fbf0523d4 --- /dev/null +++ b/include/tvm/relax/dataflow_pattern.h @@ -0,0 +1,834 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_pattern.h + * \brief A pattern language for matching dataflow properties. + */ +#ifndef TVM_RELAX_DATAFLOW_PATTERN_H_ +#define TVM_RELAX_DATAFLOW_PATTERN_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class PatternSeq; +class CallPattern; +class OrPattern; +class AndPattern; +class NotPattern; +class ShapePattern; +class RuntimeDepShapePattern; +class TypePattern; +class DataTypePattern; +class AttrPattern; + +/*! + * \brief Create used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned. + * + * \param lhs Left hand side of the used-by relationship. + * \param rhs Right hand side of the used-by relationship. + * \param index lhs[-1] is used as the index'th argument of rhs[0]. + * \return PatternSeq The concatenated sequence of [*lhs, *rhs]. + */ +TVM_DLL PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1); +/*! \brief Syntax sugar of UsedBy(lhs, rhs, -1). 
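+ *
+ * For example (illustrative), a sequence requiring the variable to feed the
+ * exp call as one of its arguments:
+ *
+ * \code{.cc}
+ * PatternSeq seq = IsVar("x") ^ IsOp("relax.exp")(Wildcard());
+ * \endcode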
+ */
+TVM_DLL PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs);
+
+/*!
+ * \brief Create only-used-by relationship between lhs[-1] and rhs[0], with [*lhs, *rhs] returned.
+ *
+ * \param lhs Left hand side of the only-used-by relationship.
+ * \param rhs Right hand side of the only-used-by relationship.
+ * \param index lhs[-1] is used as the index'th argument of rhs[0].
+ * \return PatternSeq The concatenated sequence of [*lhs, *rhs].
+ */
+TVM_DLL PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index = -1);
+/*! \brief Syntax sugar of OnlyUsedBy(lhs, rhs, -1). */
+TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
+
+/*!
+ * \brief Base type of all dataflow patterns.
+ * \sa DFPattern
+ */
+class DFPatternNode : public Object {
+ public:
+  static constexpr const char* _type_key = "DFPatternNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
+};
+
+/*!
+ * \brief Managed reference to dataflow patterns.
+ * \sa DFPatternNode
+ */
+class DFPattern : public ObjectRef {
+ public:
+  /*! \brief Syntactic sugar for creating a CallPattern */
+  template <typename... Args>
+  CallPattern operator()(Args&&... args) const;
+  /*! \brief Syntactic sugar for creating a CallPattern */
+  TVM_DLL CallPattern operator()(const std::vector<DFPattern>& args) const;
+  /*! \brief Syntactic sugar for creating an OrPattern */
+  TVM_DLL OrPattern operator|(const DFPattern& other) const;
+  /*! \brief Syntactic sugar for creating an AndPattern */
+  TVM_DLL AndPattern operator&(const DFPattern& other) const;
+  /*! \brief Syntactic sugar for creating a NotPattern */
+  TVM_DLL NotPattern operator~() const;
+  /*! \brief Syntactic sugar for creating an AttrPattern */
+  TVM_DLL AttrPattern HasAttr(const Map<String, ObjectRef>& attrs) const;
+  /*! \brief Syntactic sugar for creating a TypePattern */
+  TVM_DLL TypePattern HasType(const Type& type) const;
+  /*! \brief Syntactic sugar for creating a DataTypePattern with a DataType */
+  TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const;
+  /*! \brief Syntactic sugar for creating a DataTypePattern with a data type's name */
+  TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
+  /*! \brief Syntactic sugar for creating a ShapePattern */
+  TVM_DLL ShapePattern HasShape(const Array<PrimExpr>& shape) const;
+  /*! \brief Syntactic sugar for creating a RuntimeDepShapePattern */
+  TVM_DLL RuntimeDepShapePattern HasRuntimeDepShape() const;
+  /*! \brief Syntactic sugar for duplicating the current pattern */
+  TVM_DLL DFPattern dup() const;
+
+  /*! \brief Implicit conversion from DFPattern to PatternSeq */
+  TVM_DLL operator PatternSeq() const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
+};
+
+/*! \brief Constraint of a DFPattern edge (producer -> consumer) in graph-level matching */
+struct PairCons {
+  /*! \brief Constraint types of the edge */
+  enum Type {
+    kUsedBy,     /*!< producer ^ consumer */
+    kOnlyUsedBy, /*!< producer >> consumer */
+  } type = kUsedBy;
+  int index = -1; /*!< The argument index of the producer in the consumer caller site */
+
+  /*!
+   * \brief Construct a new PairCons object
+   *
+   * \param t The constraint type
+   * \param index The producer is called as the index'th argument of the consumer function.
+   */
+  TVM_DLL explicit PairCons(Type t, int index = -1) : type(t), index(index) {}
+
+  bool operator==(const PairCons& other) const {
+    return type == other.type && index == other.index;
+  }
+};
+
+/*!
+ * \brief A sequence of DFPatterns in which each pattern is connected to the next one.
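+ *
+ * Illustrative example of a two-element sequence whose first pattern may only
+ * be used by the second:
+ *
+ * \code{.cc}
+ * PatternSeq seq = IsOp("relax.multiply")(Wildcard(), Wildcard()) >>
+ *                  IsOp("relax.add")(Wildcard(), Wildcard());
+ * \endcode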
+ * \sa PatternSeq
+ */
+class PatternSeqNode final : public Object {
+ public:
+  tvm::Array<DFPattern> patterns;         /*!< The sequence of DFPatterns */
+  std::vector<PairCons> pair_constraints; /*!< Constraints between the previous and next patterns */
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); }
+  static constexpr const char* _type_key = "relax.dpl.PatternSeq";
+  TVM_DECLARE_BASE_OBJECT_INFO(PatternSeqNode, Object);
+};
+
+/*!
+ * \brief Managed reference to pattern sequences.
+ * \sa PatternSeqNode
+ */
+class PatternSeq final : public ObjectRef {
+ public:
+  TVM_DLL explicit PatternSeq(DFPattern init_pattern);
+  TVM_DLL explicit PatternSeq(tvm::Array<DFPattern> patterns, bool only_used_by = false);
+
+  PatternSeq UsedBy(PatternSeq other, int index = -1) const;
+  PatternSeq OnlyUsedBy(PatternSeq other, int index = -1) const;
+
+  /*! \brief Syntactic sugar for duplicating the current pattern sequence */
+  PatternSeq dup() const;
+
+  // friend functions
+  friend PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
+  friend PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PatternSeq, ObjectRef, PatternSeqNode);
+};
+
+/*!
+ * \brief A context to manage the graph-level pattern matching.
+ * \sa PatternContext
+ */
+class PatternContextNode : public Object {
+ public:
+  /*! \brief Constraining the matched graph with assertions on external uses */
+  enum ExternUse {
+    kMay,     /*!< No constraints */
+    kMustNot, /*!< All nodes except outputs only have internal dependencies in the matched graph. */
+  } allow_extern_use = kMay;
+  // src node -> constraints.
+  std::map<DFPattern, std::map<DFPattern, std::vector<PairCons>>> constraints;
+
+  static constexpr const char* _type_key = "relax.dpl.PatternContext";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object);
+};
+
+/*!
+ * \brief Managed reference to a pattern context.
+ * \sa PatternContextNode
+ */
+class PatternContext : public ObjectRef {
+ public:
+  TVM_DLL explicit PatternContext(ObjectPtr<Object> n) : ObjectRef(n) {}
+  TVM_DLL explicit PatternContext(bool incremental = false);
+
+  const PatternContextNode* operator->() const {
+    ICHECK(get() != nullptr);
+    return static_cast<const PatternContextNode*>(get());
+  }
+
+  PatternContextNode* operator->() {
+    ICHECK(get() != nullptr);
+    return static_cast<PatternContextNode*>(get_mutable());
+  }
+
+  /*!
+   * \brief Build an edge constraint between two patterns (producer and consumer).
+   *
+   * \param producer The pattern corresponding to the producer node.
+   * \param consumer The pattern corresponding to the consumer node.
+   * \param cons The constraint type. \sa PairCons
+   */
+  void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
+    auto& vec = (*this)->constraints[producer][consumer];
+    ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) << "Constraint already exists";
+    vec.push_back(cons);
+  }
+
+  /*! \brief Get the pattern context on the top of the stack */
+  TVM_DLL static PatternContext Current();
+
+  class Internal;
+
+ private:
+  /*! \brief The RAII-like entry of a pattern context scope */
+  TVM_DLL void EnterWithScope();
+  /*! \brief The RAII-like exit of a pattern context scope */
+  TVM_DLL void ExitWithScope();
+  friend class Internal;
+  friend class With<PatternContext>;
+};
+
+/*!
+ * \brief Pattern for Relax Expression.
+ * \sa ExprPattern + */ +class ExprPatternNode : public DFPatternNode { + public: + Expr expr; /*!< The expression to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } + + static constexpr const char* _type_key = "relax.dpl.ExprPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to an ExprPattern. + * \sa ExprPatternNode + */ +class ExprPattern : public DFPattern { + public: + TVM_DLL explicit ExprPattern(Expr expr); + TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Variable. + * \note The name field matches any string if it is empty. + * \sa VarPattern + */ +class VarPatternNode : public DFPatternNode { + public: + String name; + const String& name_hint() const { return name; } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); } + + static constexpr const char* _type_key = "relax.dpl.VarPattern"; + TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a VarPattern. + * \sa VarPatternNode + */ +class VarPattern : public DFPattern { + public: + /*! + * \brief Create a pattern matching by variable name. + * + * \param name_hint Variable name to match. Any if empty (""). + */ + TVM_DLL VarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Dataflow Variable + * \sa DataflowVarPattern + */ +class DataflowVarPatternNode : public VarPatternNode { + public: + static constexpr const char* _type_key = "relax.dpl.DataflowVarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a DataflowVarPattern. + * \sa DataflowVarPatternNode + */ +class DataflowVarPattern : public DFPattern { + public: + /*! \sa VarPattern::VarPattern */ + TVM_DLL DataflowVarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVarPattern, DFPattern, DataflowVarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Global Variable + * \sa GlobalVarPattern + */ +class GlobalVarPatternNode : public VarPatternNode { + public: + static constexpr const char* _type_key = "relax.dpl.GlobalVarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a GlobalVarPattern. + * \sa GlobalVarPatternNode + */ +class GlobalVarPattern : public DFPattern { + public: + TVM_DLL GlobalVarPattern(String name_hint); + TVM_DEFINE_OBJECT_REF_METHODS(GlobalVarPattern, DFPattern, GlobalVarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relax Constant. + * \sa ConstantPattern + */ +class ConstantPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.ConstantPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a ConstantPattern. + * \sa ConstantPatternNode + */ +class ConstantPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode); +}; + +/*! + * \brief A pattern to match a callable node in Relax. + * \sa CallPattern + */ +class CallPatternNode : public DFPatternNode { + public: + /*! + * \note The op field can be: + * - relay::Op which corresponds to the primitive operators. + * - user defined functions (Function, GlobalVar, Var). 
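+   *
+   * For example (illustrative), a call to relax.multiply with any two
+   * arguments can be matched by:
+   *
+   * \code{.cc}
+   * CallPattern pat = IsOp("relax.multiply")(Wildcard(), Wildcard());
+   * \endcode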
+ */ + DFPattern op; /*!< The operator (function) being invoked */ + tvm::Array args; /*!< The arguments of the function call */ + /*! + * \note If varg_default_wildcard is true. Given args of [pA, pB], when matching a call whose + * arguments are [A, B, ...], the pattern will still match despite #args < #call.args. That said, + * with varg_default_wildcard set to true, we match the args in the order we have, and regard the + * rest of the arguments as wildcards. + */ + bool varg_default_wildcard; /*!< #args can be < #real args with the rest padded by Wildcard() */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + } + + static constexpr const char* _type_key = "relax.dpl.CallPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); +}; + +class CallPattern : public DFPattern { + public: + TVM_DLL CallPattern(DFPattern op, Array args, bool varg_default_wildcard = false); + TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); +}; + +/*! + * \brief A pattern to match an array of PrimExpr. + * \sa PrimArrPattern + * \note This is often used to match shapes specified as arguments to a function. + */ +class PrimArrPatternNode : public DFPatternNode { + public: + Array fields; /*!< The array to match */ + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + static constexpr const char* _type_key = "relax.dpl.PrimArrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrimArrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to a PrimArrPattern. + * \sa PrimArrPatternNode + */ +class PrimArrPattern : public DFPattern { + public: + TVM_DLL PrimArrPattern(Array arr); + TVM_DEFINE_OBJECT_REF_METHODS(PrimArrPattern, DFPattern, PrimArrPatternNode); +}; + +/*! + * \brief A pattern to match a Relax Function + * \sa Function + * \sa FunctionPattern + */ +class FunctionPatternNode : public DFPatternNode { + public: + tvm::Array params; /*!< The parameters of the function */ + /*! + * \note Note that in Relax, the function body is a SeqExpr which contains + * 1) SeqExprNode::blocks, which is a list of blocks of statements; and 2) + * SeqExprNode::body, which is an Expr that can be anything. FunctionPattern + * only matches the body of the function (writing patterns to statements is tricky). + */ + DFPattern body; /*!< The body of the function */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("params", ¶ms); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "relax.dpl.FunctionPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to FunctionPatternNode. + * \sa FunctionPatternNode + */ +class FunctionPattern : public DFPattern { + public: + /*! + * \brief Constructor + * \param params The parameters of the function. + * \param body The body of the function. + */ + TVM_DLL FunctionPattern(tvm::Array params, DFPattern body); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); +}; + +/*! + * \brief Pattern to match a tuple of ordered expressions. + * \sa TuplePattern + */ +class TuplePatternNode : public DFPatternNode { + public: + tvm::Array fields; /*!< The fields of the tuple */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relax.dpl.TuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TuplePatternNode. 
+ * \sa TuplePatternNode + */ +class TuplePattern : public DFPattern { + public: + TVM_DLL explicit TuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); +}; + +/*! + * \brief A pattern to match multiple expressions unorderedly. + * \sa UnorderedTuplePattern + */ +class UnorderedTuplePatternNode : public DFPatternNode { + public: + tvm::Array fields; /*!< The fields of the tuple */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relax.dpl.UnorderedTuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(UnorderedTuplePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to UnorderedTuplePatternNode. + * \sa UnorderedTuplePatternNode + */ +class UnorderedTuplePattern : public DFPattern { + public: + TVM_DLL explicit UnorderedTuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(UnorderedTuplePattern, DFPattern, UnorderedTuplePatternNode); +}; + +/*! + * \brief A pattern to match n'th indexing to a tuple. + * \sa TupleGetItem + * \sa TupleGetItemPattern + */ +class TupleGetItemPatternNode : public DFPatternNode { + public: + DFPattern tuple; /*!< The tuple Expression */ + int index; /*!< The index of the tuple with -1 meaning arbitrary */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple", &tuple); + v->Visit("index", &index); + } + + static constexpr const char* _type_key = "relax.dpl.TupleGetItemPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TupleGetItemPatternNode. + * \sa TupleGetItemPatternNode + */ +class TupleGetItemPattern : public DFPattern { + public: + TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); +}; + +/*! + * \brief Match a conjunction of other patterns. + * \sa AndPattern + */ +class AndPatternNode : public DFPatternNode { + public: + DFPattern left; /*!< The left hand side of the conjunction */ + DFPattern right; /*!< The right hand side of the conjunction */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relax.dpl.AndPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AndPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to AndPatternNode. + * \sa AndPatternNode + */ +class AndPattern : public DFPattern { + public: + TVM_DLL AndPattern(DFPattern lhs, DFPattern rhs); + TVM_DEFINE_OBJECT_REF_METHODS(AndPattern, DFPattern, AndPatternNode); +}; + +/*! + * \brief Match a disjunction of other patterns. + * \sa OrPattern + */ +class OrPatternNode : public DFPatternNode { + public: + DFPattern left; /*!< The left hand side of the disjunction */ + DFPattern right; /*!< The right hand side of the disjunction */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relax.dpl.OrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(OrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to OrPatternNode. + * \sa OrPatternNode + */ +class OrPattern : public DFPattern { + public: + TVM_DLL OrPattern(DFPattern left, DFPattern right); + TVM_DEFINE_OBJECT_REF_METHODS(OrPattern, DFPattern, OrPatternNode); +}; + +/*! + * \brief Pattern for rejecting a certain pattern. 
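+ *
+ * For example (illustrative), `~IsConst()` matches any node that is not a
+ * relax constant:
+ *
+ * \code{.cc}
+ * NotPattern not_const = ~IsConst();
+ * \endcode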
+ * \sa NotPattern + */ +class NotPatternNode : public DFPatternNode { + public: + DFPattern reject; /*!< The pattern to reject */ + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("reject", &reject); } + + static constexpr const char* _type_key = "relax.dpl.NotPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(NotPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to NotPatternNode. + * \sa NotPatternNode + */ +class NotPattern : public DFPattern { + public: + TVM_DLL NotPattern(DFPattern reject); + TVM_DEFINE_OBJECT_REF_METHODS(NotPattern, DFPattern, NotPatternNode); +}; + +/*! + * \brief Wildcard Pattern is a pattern that can match anything. + * \sa WildcardPattern + */ +class WildcardPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.WildcardPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to WildcardPatternNode. + * \sa WildcardPatternNode + */ +class WildcardPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); +}; + +/*! + * \brief Pattern for matching a certain type. + * \sa TypePattern + */ +class TypePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The pattern to match */ + Type type; /*!< The type to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("type", &type); + } + + static constexpr const char* _type_key = "relax.dpl.TypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to TypePatternNode. + * \sa TypePatternNode + */ +class TypePattern : public DFPattern { + public: + TVM_DLL TypePattern(DFPattern pattern, Type type); + TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a certain shape. + * \sa ShapePattern + */ +class ShapePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + Array shape; /*!< The shape to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "relax.dpl.ShapePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to ShapePatternNode. + * \sa ShapePatternNode + */ +class ShapePattern : public DFPattern { + public: + TVM_DLL ShapePattern(DFPattern pattern, Array type); + TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a certain data type. + * \sa DataTypePattern + */ +class DataTypePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + DataType dtype; /*!< The data type to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "relax.dpl.DataTypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to DataTypePatternNode. + * \sa DataTypePatternNode + */ +class DataTypePattern : public DFPattern { + public: + TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); +}; + +/*! 
+ * \brief A pattern that asserting a root pattern has certain attributes. + * \sa AttrPattern + */ +class AttrPatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + DictAttrs attrs; /*!< The attributes (a map/dictionary) to match */ + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("attrs", &attrs); + } + + static constexpr const char* _type_key = "relax.dpl.AttrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to AttrPatternNode. + * \sa AttrPatternNode + */ +class AttrPattern : public DFPattern { + public: + TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs); + TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); +}; + +/*! + * \brief A pattern of external function. + * \sa ExternFunc + * \sa ExternFuncPattern + */ +class ExternFuncPatternNode : public DFPatternNode { + public: + String global_symbol_; /*!< The global symbol name of the external function */ + + /*! \brief The the external function name */ + const String& global_symbol() const { return global_symbol_; } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("global_symbol", &global_symbol_); } + + static constexpr const char* _type_key = "relax.dpl.ExternFuncPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncPatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to ExternFuncPatternNode. + * \sa ExternFuncPatternNode + */ +class ExternFuncPattern : public DFPattern { + public: + TVM_DLL ExternFuncPattern(String global_symbol); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFuncPattern, DFPattern, ExternFuncPatternNode); +}; + +/*! + * \brief A pattern that asserting a root pattern has a runtime-dependent shape. + * \sa RuntimeDepShape + * \sa RuntimeDepShapePattern + */ +class RuntimeDepShapePatternNode : public DFPatternNode { + public: + DFPattern pattern; /*!< The root pattern to match */ + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relax.dpl.RuntimeDepShapePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeDepShapePatternNode, DFPatternNode); +}; + +/*! + * \brief Managed reference to RuntimeDepShapePatternNode. + * \sa RuntimeDepShapePatternNode + */ +class RuntimeDepShapePattern : public DFPattern { + public: + TVM_DLL explicit RuntimeDepShapePattern(DFPattern pattern); + TVM_DEFINE_OBJECT_REF_METHODS(RuntimeDepShapePattern, DFPattern, RuntimeDepShapePatternNode); +}; + +/*! \brief Syntatic Sugar for creating a VarPattern with a name */ +VarPattern IsVar(const String& name); +/*! \brief Syntatic Sugar for creating a ConstantPattern */ +ConstantPattern IsConst(); +/*! \brief Syntatic Sugar for creating a WildcardPattern */ +WildcardPattern Wildcard(); +/*! \brief Syntatic Sugar for creating a ExprPattern */ +ExprPattern IsExpr(const Expr& expr); +/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op */ +ExprPattern IsOp(const String& op_name); +/*! \brief Syntatic Sugar for call_tir (return a tensor) */ +CallPattern IsCallTIR(const String& name, Optional args = NullOpt, + Optional> oshape = NullOpt); +/*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ +CallPattern IsCallTIR(const String& name, TuplePattern var_args, Array> oshapes); +/*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */ +DFPattern IsTuple(const Array& fields, bool unordered = false); +/*! 
\brief Syntatic Sugar for creating a TupleGetItemPattern */ +TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index = -1); + +/*! \brief Implementation of the templated CallPattern syntax sugar */ +template +CallPattern DFPattern::operator()(Args&&... args) const { + return CallPattern(GetRef(this->get()), + Array({std::forward(args)...})); +} + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_PATTERN_H_ diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h new file mode 100644 index 0000000000..4ac5fe1173 --- /dev/null +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_pattern_functor.h + * \brief Functors and visitors for dataflow patterns. + */ +#ifndef TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ +#define TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ + +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signature + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAX_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... args) { + ICHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const OrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AndPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const NotPatternNode* op, Args... 
args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+
+  virtual R VisitDFPattern_(const RuntimeDepShapePatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const DataflowVarPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const GlobalVarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ExternFuncPatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const PrimArrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const UnorderedTuplePatternNode* op,
+                            Args... args) DFPATTERN_FUNCTOR_DEFAULT;
+
+  virtual R VisitDFPatternDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(OrPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(AndPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(NotPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
+
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(RuntimeDepShapePatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode);
+    RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode);
+    return vtable;
+  }
+};
+
+/*!
+ * \brief A simple visitor wrapper around DFPatternFunctor.
+ * Recursively visits the content.
+ *
+ * DFPatternVisitor treats the pattern as a dataflow graph, and only visits each pattern node once.
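+ *
+ * A minimal subclass, sketched for illustration, that counts the pattern
+ * nodes reachable from a root pattern:
+ *
+ * \code{.cc}
+ * class PatternCounter : public DFPatternVisitor {
+ *  public:
+ *   size_t count = 0;
+ *   void VisitDFPattern(const DFPattern& pattern) override {
+ *     if (!visited_.count(pattern.get())) ++count;  // count first visits only
+ *     DFPatternVisitor::VisitDFPattern(pattern);    // recurse with dedup
+ *   }
+ * };
+ * \endcode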
+ */ +class DFPatternVisitor : public DFPatternFunctor { + public: + void VisitDFPattern(const DFPattern& pattern) override; + void VisitDFPattern_(const OrPatternNode* op) override; + void VisitDFPattern_(const AndPatternNode* op) override; + void VisitDFPattern_(const NotPatternNode* op) override; + void VisitDFPattern_(const AttrPatternNode* op) override; + void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const ConstantPatternNode* op) override; + void VisitDFPattern_(const DataTypePatternNode* op) override; + void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const FunctionPatternNode* op) override; + void VisitDFPattern_(const ShapePatternNode* op) override; + void VisitDFPattern_(const TupleGetItemPatternNode* op) override; + void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; + void VisitDFPattern_(const WildcardPatternNode* op) override; + void VisitDFPattern_(const VarPatternNode* op) override; + + void VisitDFPattern_(const RuntimeDepShapePatternNode* op) override; + void VisitDFPattern_(const DataflowVarPatternNode* op) override; + void VisitDFPattern_(const GlobalVarPatternNode* op) override; + void VisitDFPattern_(const ExternFuncPatternNode* op) override; + void VisitDFPattern_(const PrimArrPatternNode* op) override; + void VisitDFPattern_(const UnorderedTuplePatternNode* op) override; + + protected: + // set of already-visited nodes + std::unordered_set visited_; +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h new file mode 100644 index 0000000000..0f0fad82e2 --- /dev/null +++ b/include/tvm/relax/exec_builder.h @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/exec_builder.h + */ +#ifndef TVM_RELAX_EXEC_BUILDER_H_ +#define TVM_RELAX_EXEC_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +namespace vm = tvm::runtime::relax_vm; + +class ExecBuilder; + +/*! + * \brief A builder provides api to build VM executable with instructions. + */ +class ExecBuilderNode : public Object { + public: + /*! \brief The mutable internal executable. */ + ObjectPtr exec; // mutable + /*! + * \brief To annotate the start of a vm function. + * \param func The function name. + * \param num_inputs The number of inputs. + * \param param_names The function parameter names. + */ + void EmitFunction(std::string func, int64_t num_inputs, Array param_names); + /*! + * \brief Emit a call instruction for a packed function. 
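+   *
+   * A typical emission sequence, sketched for illustration (the packed
+   * function name and registers are assumptions; argument construction is
+   * elided since it depends on the instruction encoding):
+   *
+   * \code{.cc}
+   * ExecBuilder eb = ExecBuilderNode::Create();
+   * eb->EmitFunction("main", /*num_inputs=*/1, {"x"});
+   * std::vector<vm::Instruction::Arg> call_args;  // register/constant operands
+   * eb->EmitCall("vm.builtin.alloc_storage", call_args, /*ret=*/1);
+   * eb->EmitRet(/*result=*/1);
+   * \endcode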
+ * \param func The packed function name. + * \param args The arguments of the function. + * \param ret The return register. + */ + void EmitCall(std::string func, std::vector args, vm::RegName ret); + /*! + * \brief Emit a ret instruction. + * \param result The return result. + */ + void EmitRet(vm::RegName result); + /*! + * \brief Emit a goto instruction. + * \param pc_offset The program counter offset as the jump offset. + */ + void EmitGoto(vm::Index pc_offset); + /*! + * \brief Emit an If instruction. + * \param cond The register containing the cond value. + * \param false_offset The program counter offset for the false branch. + */ + void EmitIf(vm::RegName cond, vm::Index false_offset); + /*! + * \brief Emit a constant value to the constant pool. + * \param obj The constant value to be emitted + * \return The index that represents the constant. + */ + vm::Index EmitConstant(TVMRetValue obj); + /*! + * \brief Get the built executable. + * \return The built executable. + */ + ObjectPtr Get(); + /*! + * \brief Create an ExecBuilder. + * \return The ExecBuilder. + */ + TVM_DLL static ExecBuilder Create(); + + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.ExecBuilder"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object); + + private: + /*! + * \brief A helper function to check if an executable is legal by checking if registers are used + * properly + */ + void CheckExecutable(); + /*! + * \brief Formalize the executable. + */ + void Formalize(); +}; + +class ExecBuilder : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecBuilder, ObjectRef, ExecBuilderNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_EXEC_BUILDER_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h new file mode 100644 index 0000000000..6c2d8fda52 --- /dev/null +++ b/include/tvm/relax/expr.h @@ -0,0 +1,539 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAX_EXPR_H_ +#define TVM_RELAX_EXPR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using Expr = RelayExpr; +using ExprNode = RelayExprNode; +using relay::Call; +using relay::CallNode; +using relay::Constant; +using relay::ConstantNode; +using relay::Id; +using relay::If; +using relay::IfNode; +using relay::Tuple; +using relay::TupleGetItem; +using relay::TupleGetItemNode; +using relay::TupleNode; + +/*! \brief A shape expression which allows users to construct a shape containing PrimExpr. + */ +class ShapeExprNode : public ExprNode { + public: + /*! The values of the shape expression. 
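+   *
+   * For example (illustrative), a rank-2 shape with one symbolic dimension:
+   *
+   * \code{.cc}
+   * tir::Var m("m", DataType::Int(64));
+   * ShapeExpr shape({m, 32});  // represents the shape [m, 32]
+   * \endcode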
*/ + Array values; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("shape_", &shape_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { + return equal(values, other->values) && equal(checked_type_, other->checked_type_) && + equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(checked_type_); + hash_reduce(shape_); + } + + static constexpr const char* _type_key = "relax.expr.ShapeExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, ExprNode); +}; + +class ShapeExpr : public Expr { + public: + TVM_DLL explicit ShapeExpr(Array values, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode); +}; + +/*! \brief Runtime dependent shape expression. + * + * Sometimes shape of a tensor cannot be deduced statically either because the shape is truly data + * dependent such as output of `unique` operator or cannot be deduced because of limited shape + * inference capability. + */ +class RuntimeDepShapeNode : public ExprNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape_", &shape_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const RuntimeDepShapeNode* other, SEqualReducer equal) const { + return equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(checked_type_); + hash_reduce(shape_); + } + + static constexpr const char* _type_key = "relax.expr.RuntimeDepShape"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeDepShapeNode, ExprNode); +}; + +class RuntimeDepShape : public Expr { + public: + TVM_DLL explicit RuntimeDepShape(Span span = Span()); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RuntimeDepShape, Expr, RuntimeDepShapeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RuntimeDepShapeNode); +}; + +/*! \brief The variable class for all Relax bindings. */ +class VarNode : public ExprNode { + public: + /*! \brief The identifier of the variable, which is used for comparing stable equality across + * transformations. */ + Id vid; + + /*! 
\return The name hint of the variable */ + const String& name_hint() const { return vid->name_hint; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("_checked_type_", &checked_type_); + v->Visit("vid", &vid); + v->Visit("span", &span); + v->Visit("shape_", &shape_); + } + + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(checked_type_, other->checked_type_) && + equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(shape_); + hash_reduce(checked_type_); + } + + static constexpr const char* _type_key = "relax.expr.Var"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 2; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, ExprNode); +}; + +class Var : public Expr { + public: + TVM_DLL explicit Var(String name_hint, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()) + : Var(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL explicit Var(Id vid, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); +}; + +/*! \brief A sub-type of the variable node used to mark dataflow variables from + * normal visible "function local" bindings. + */ +class DataflowVarNode : public VarNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("span", &span); + v->Visit("shape_", &shape_); + v->Visit("_checked_type_", &checked_type_); + } + + bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(vid, other->vid) && equal(shape_, other->shape_) && + equal(checked_type_, other->checked_type_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(shape_); + hash_reduce(checked_type_); + } + + static constexpr const char* _type_key = "relax.expr.DataflowVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); +}; + +class DataflowVar : public Var { + public: + TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()) + : DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL explicit DataflowVar(Id vid, runtime::Optional shape_annotation, + runtime::Optional type_annotation, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode); +}; + +/*! \brief The base class of a variable binding in Relax. 
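+ *
+ * Its concrete subclasses are VarBinding and MatchShape, declared below. A
+ * minimal hedged sketch of the former (names are illustrative only):
+ * \code
+ * relax::Var x("x", NullOpt, NullOpt);
+ * relax::Var y("y", NullOpt, NullOpt);
+ * relax::VarBinding binding(x, y);  // binds the value y to the variable x
+ * \endcode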
*/ +class BindingNode : public Object { + public: + mutable Span span; + + void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; } + void SHashReduce(SHashReducer hash_reduce) const {} + + static constexpr const char* _type_key = "relax.expr.Binding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); +}; + +class Binding : public ObjectRef { + protected: + Binding() = default; + + public: + explicit Binding(ObjectPtr n) : ObjectRef(n) {} + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding); + const BindingNode* operator->() const { return static_cast(data_.get()); } + const BindingNode* get() const { return operator->(); } + using ContainerType = BindingNode; +}; + +/*! \brief Symbolic shape match, binds the variable of the lhs with the rhs. */ +class MatchShape; +class MatchShapeNode : public BindingNode { + public: + Expr value; + Array pattern; + Var var; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("pattern", &pattern); + v->Visit("var", &var); + v->Visit("span", &span); + } + + bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + return equal(value, other->value) && equal.DefEqual(pattern, other->pattern) && + equal.DefEqual(var, other->var); + } + + void SHashReduce(SHashReducer hash_reduce) const { + // NOTE: pattern can contain ShapeExpr which defines the vars + hash_reduce(value); + hash_reduce.DefHash(pattern); + hash_reduce.DefHash(var); + } + + static constexpr const char* _type_key = "relax.expr.MatchShape"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchShapeNode, BindingNode); +}; + +class MatchShape : public Binding { + public: + TVM_DLL explicit MatchShape(Expr value, Array pattern, Var var, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode); +}; + +class VarBinding; +class VarBindingNode : public BindingNode { + public: + Var var; + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("span", &span); + } + + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + return equal.DefEqual(var, other->var) && equal(value, other->value); + } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(var); + hash_reduce(value); + } + static constexpr const char* _type_key = "relax.expr.VarBinding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); +}; + +class VarBinding : public Binding { + public: + TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); +}; + +class BindingBlock; + +class BindingBlockNode : public Object { + public: + mutable Span span; + Array bindings; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("span", &span); + v->Visit("bindings", &bindings); + } + + bool SEqualReduce(const BindingBlockNode* other, 
SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.BindingBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); +}; + +class BindingBlock : public ObjectRef { + public: + TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); +}; + +class DataflowBlock; +class DataflowBlockNode : public BindingBlockNode { + public: + bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); +}; + +class DataflowBlock : public BindingBlock { + public: + TVM_DLL explicit DataflowBlock(Array bindings, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode); +}; + +/*! \brief A sequence of blocks followed by an expression. + * + * The order of blocks enforces scoping and ordering. + */ +class SeqExprNode : public ExprNode { + public: + Array blocks; + Expr body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("blocks", &blocks); + v->Visit("body", &body); + v->Visit("shape_", &shape_); + v->Visit("_checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { + return equal(blocks, other->blocks) && equal(body, other->body) && + equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(blocks); + hash_reduce(body); + hash_reduce(shape_); + hash_reduce(checked_type_); + } + + static constexpr const char* _type_key = "relax.expr.SeqExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); +}; + +class SeqExpr : public Expr { + public: + TVM_DLL explicit SeqExpr(Array blocks, Expr body, Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode); +}; + +/*! \brief A Relax function, eventually to replace the current Relay function definition. */ +class FunctionNode : public BaseFuncNode { + public: + /*! \brief The parameters to the function. */ + Array params; + /*! \brief The body of the function. */ + Expr body; + /*! \brief The return type of the function. */ + Type ret_type; + /*! \brief The return shape of the function. 
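+   *
+   * A hedged sketch of the two common cases (m and n stand for tir::Var
+   * dimensions and are illustrative only):
+   * \code
+   * Expr known = relax::ShapeExpr({m, n});    // shape known symbolically
+   * Expr unknown = relax::RuntimeDepShape();  // shape only known at runtime
+   * \endcode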
+ */
+  Expr ret_shape;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("params", &params);
+    v->Visit("body", &body);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("ret_shape", &ret_shape);
+    v->Visit("_checked_type_", &checked_type_);
+    v->Visit("shape_", &shape_);
+    v->Visit("span", &span);
+    v->Visit("attrs", &attrs);
+  }
+
+  bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal.DefEqual(params, other->params) && equal(body, other->body) &&
+           equal(ret_type, other->ret_type) && equal(ret_shape, other->ret_shape) &&
+           equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_) &&
+           equal(attrs, other->attrs);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce.DefHash(params);
+    hash_reduce(body);
+    hash_reduce(ret_type);
+    hash_reduce(ret_shape);
+    hash_reduce(checked_type_);
+    hash_reduce(shape_);
+    hash_reduce(attrs);
+  }
+
+  static constexpr const char* _type_key = "relax.expr.Function";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
+};
+
+class Function : public BaseFunc {
+ public:
+  TVM_DLL explicit Function(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
+                            DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+
+  /*!
+   * \brief Mimics the constructor, but without type checking.
+   */
+  TVM_DLL static Function CreateUnchecked(Array<Var> params, Expr body, Type ret_type,
+                                          Expr ret_shape, DictAttrs attrs = NullValue<DictAttrs>(),
+                                          Span span = Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
+};
+
+// TODO(@sunggg): Investigate the exact usage of kComposite, kPartitionedFromPattern, and
+// kPrimitive.
+namespace attr {
+/*! \brief Mark the function as a primitive function. */
+constexpr const char* kPrimitive = "Primitive";
+/*!
+ * \brief Indicate the codegen that should be used for building this function.
+ * When this is unset or set to "default", the default compilation pipeline will be used.
+ */
+constexpr const char* kCodegen = "Codegen";
+/*! \brief Treat the function as a composite operator. */
+constexpr const char* kComposite = "Composite";
+/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
+constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
+}  // namespace attr
+
+/*! \brief The extern function, which can represent a packed function. */
+class ExternFuncNode : public BaseFuncNode {
+ public:
+  /*! \brief The name of the global symbol.
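+   *
+   * For example (the symbol name is illustrative only, not one registered by
+   * this header):
+   * \code
+   * relax::ExternFunc f("my_packed_add");  // resolved against registered packed functions
+   * \endcode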
+   */
+  String global_symbol;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("global_symbol", &global_symbol);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const {
+    return equal(global_symbol, other->global_symbol);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); }
+
+  static constexpr const char* _type_key = "relax.expr.ExternFunc";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode);
+};
+
+class ExternFunc : public BaseFunc {
+ public:
+  TVM_DLL ExternFunc(String global_symbol, Span span = Span());
+  TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, BaseFunc, ExternFuncNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
+};
+
+/*!
+ * \brief Update the type of an Expr.
+ * \param expr The Expr whose type is to be updated.
+ * \param type The type assigned to the checked_type_ of \p expr.
+ * \note We ensure idempotence, that is, we can only update the checked_type_ of an Expr if it's
+ * nullptr.
+ */
+void UpdateType(Expr expr, Type type);
+
+/*!
+ * \brief Update the shape of an Expr.
+ * \param expr The Expr whose shape is to be updated.
+ * \param shape The shape assigned to the shape_ of \p expr.
+ * \note We ensure idempotence, that is, we can only update the shape_ of an Expr if it's nullptr.
+ */
+void UpdateShape(Expr expr, Optional<ObjectRef> shape);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_EXPR_H_
diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
new file mode 100644
index 0000000000..29b449c033
--- /dev/null
+++ b/include/tvm/relax/expr_functor.h
@@ -0,0 +1,881 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/expr_functor.h
+ * \brief A more powerful visitor which enables defining arbitrary function
+ * signatures with type-based dispatch on the first argument.
+ */
+#ifndef TVM_RELAX_EXPR_FUNCTOR_H_
+#define TVM_RELAX_EXPR_FUNCTOR_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A dynamic functor that dispatches on the first Expr argument.
+ * You can use this as a more powerful visitor, since it allows you to
+ * define the signatures of the visit functions.
+ *
+ * \sa tvm/ir_functor.h
+ *
+ * \tparam FType function signature
+ * This type is only defined for FType with function signature R(const Expr&,
+ * Args...)
+ */
+template <typename FType>
+class ExprFunctor;
+
+// functions to be overridden.
+#define EXPR_FUNCTOR_DEFAULT \
+  { return VisitExprDefault_(op, std::forward<Args>(args)...); }
+
+#define RELAX_EXPR_FUNCTOR_DISPATCH(OP)                                                    \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {     \
+    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
+  });
+
+#define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \
+  {                                                       \
+    if (PY_FUNC != nullptr)                               \
+      PY_FUNC(N);                                         \
+    else                                                  \
+      DEFAULT_FUNC;                                       \
+  }
+
+#define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \
+  {                                                                 \
+    if (PY_FUNC != nullptr) {                                       \
+      RET_TYPE ret = PY_FUNC(N);                                    \
+      return ret;                                                   \
+    } else {                                                        \
+      return DEFAULT_FUNC;                                          \
+    }                                                               \
+  }
+
+#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC)                            \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    if (self->PY_FUNC != nullptr)                                        \
+      self->PY_FUNC(n);                                                  \
+    else                                                                 \
+      self->VisitExpr_(static_cast<const OP*>(n.get()));                 \
+  });
+
+#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC)                            \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    if (self->PY_FUNC != nullptr) {                                      \
+      Expr expr = self->PY_FUNC(n);                                      \
+      return expr;                                                       \
+    } else {                                                             \
+      return self->VisitExpr_(static_cast<const OP*>(n.get()));          \
+    }                                                                    \
+  });
+
+#define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP)                          \
+  post_order_vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self) { \
+    return self->VisitExprPostOrder_(static_cast<const OP*>(n.get()));              \
+  });
+
+template <typename R, typename... Args>
+class ExprFunctor<R(const Expr& n, Args...)> {
+ private:
+  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
+  using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
+
+ public:
+  /*! \brief the result type of this functor */
+  using result_type = R;
+  /*! \brief virtual destructor */
+  virtual ~ExprFunctor() {}
+  /*!
+   * \brief Same as call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward<Args>(args)...); }
+  /*!
+   * \brief The functor call.
+   * \param n The expression node.
+   * \param args Additional arguments.
+   * \return The result of the call
+   */
+  virtual R VisitExpr(const Expr& n, Args... args) {
+    ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
+                           "have generated invalid data.";
+    static FType vtable = InitVTable();
+    return vtable(n, this, std::forward<Args>(args)...);
+  }
+  // Functions that can be overridden by subclasses.
+  virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const DataflowVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ShapeExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const RuntimeDepShapeNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ExternFuncNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SeqExprNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExprDefault_(const Object* op, Args...) {
+    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
+    throw;
+  }
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAX_EXPR_FUNCTOR_DISPATCH(ConstantNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(TupleNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(VarNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(DataflowVarNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(ShapeExprNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(RuntimeDepShapeNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(ExternFuncNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(IfNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(OpNode);
+    RELAX_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
+    return vtable;
+  }
+};
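+
+// Illustrative sketch only, not part of this header's API: a subclass may pick
+// any return type via the FType signature; unhandled node kinds fall through to
+// VisitExprDefault_.
+//
+//   class ArityCounter : public ExprFunctor<int(const Expr&)> {
+//    public:
+//     int VisitExpr_(const CallNode* op) override { return static_cast<int>(op->args.size()); }
+//     int VisitExprDefault_(const Object* op) override { return 0; }
+//   };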
+
+/*!
+ * \brief A simple visitor wrapper around ExprFunctor.
+ *  Recursively visit the content.
+ */
+class ExprVisitor : public ExprFunctor<void(const Expr&)> {
+ public:
+  /*!
+   * \brief Generic dispatcher for Expr.
+   * \param expr The expr to be visited.
+   */
+  void VisitExpr(const Expr& expr) override;
+  // specific leaf level visitor functions
+  void VisitExpr_(const ConstantNode* op) override;
+  void VisitExpr_(const TupleNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const DataflowVarNode* op) override;
+  void VisitExpr_(const ShapeExprNode* op) override;
+  void VisitExpr_(const RuntimeDepShapeNode* op) override;
+  void VisitExpr_(const ExternFuncNode* op) override;
+  void VisitExpr_(const GlobalVarNode* op) override;
+  void VisitExpr_(const FunctionNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const SeqExprNode* op) override;
+  void VisitExpr_(const IfNode* op) override;
+  void VisitExpr_(const OpNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
+
+  /*!
+   * \brief Generic dispatcher for bindings.
+   * \param binding The binding to be visited.
+   */
+  virtual void VisitBinding(const Binding& binding);
+  // specific leaf level visitor functions
+  virtual void VisitBinding_(const VarBindingNode* binding);
+  virtual void VisitBinding_(const MatchShapeNode* binding);
+
+  /*!
+   * \brief Generic dispatcher for binding blocks.
+   * \param block The binding block to be visited.
+   */
+  virtual void VisitBindingBlock(const BindingBlock& block);
+  // specific leaf level visitor functions
+  virtual void VisitBindingBlock_(const BindingBlockNode* block);
+  virtual void VisitBindingBlock_(const DataflowBlockNode* block);
+
+  /*!
+   * \brief Generic dispatcher for visiting the var definition site.
+   * \param var The var to be visited.
+   * \note VisitExpr_(const VarNode*) will only visit the usage site of a Var.
+   */
+  virtual void VisitVarDef(const Var& var);
+  // specific leaf level visitor functions
+  virtual void VisitVarDef_(const VarNode* var);
+  virtual void VisitVarDef_(const DataflowVarNode* var);
+
+  virtual void VisitType(const Type& t);
+  virtual void VisitSpan(const Span& span);
+};
+
+void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
+
+/*!
+ * \brief A mutator that works on the unnormalized form.
+ *
+ * ExprMutatorBase expects the input AST to be in the unnormalized form, i.e., checked_type_ and
+ * shape_ of expressions can be nullptr, and the expressions may nest (and as a result the AST is
+ * not in ANF).
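+ *
+ * A hedged usage sketch (the subclass below is ours, purely for illustration):
+ * \code
+ * class RebuildCalls : public ExprMutatorBase {
+ *   Expr VisitExpr_(const CallNode* op) override {
+ *     // Rewrite the children first, then fall back to the default rebuild.
+ *     return ExprMutatorBase::VisitExpr_(op);
+ *   }
+ * };
+ * \endcode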
+ */
+class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
+ public:
+  Expr VisitExpr(const Expr& expr) override;
+  Expr VisitExpr_(const ConstantNode* op) override;
+  Expr VisitExpr_(const TupleNode* op) override;
+  Expr VisitExpr_(const VarNode* op) override;
+  Expr VisitExpr_(const DataflowVarNode* op) override;
+  Expr VisitExpr_(const ShapeExprNode* op) override;
+  Expr VisitExpr_(const RuntimeDepShapeNode* op) override;
+  Expr VisitExpr_(const ExternFuncNode* op) override;
+  Expr VisitExpr_(const GlobalVarNode* op) override;
+  Expr VisitExpr_(const FunctionNode* op) override;
+  Expr VisitExpr_(const CallNode* op) override;
+  Expr VisitExpr_(const SeqExprNode* op) override;
+  Expr VisitExpr_(const IfNode* op) override;
+  Expr VisitExpr_(const OpNode* op) override;
+  Expr VisitExpr_(const TupleGetItemNode* op) override;
+
+  /*!
+   * \brief Mutate BindingBlock.
+   * \param block The binding block to be visited.
+   * \return The binding block after transformation.
+   */
+  virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
+
+  /*!
+   * \brief Used to visit the types inside of expressions.
+   *
+   * Can be overloaded to transform the types in arbitrary
+   * ways, one way would be to define a sub-class of type
+   * visitor for types which transform them appropriately.
+   */
+  virtual Type VisitType(const Type& t);
+};
+
+/*!
+ * \brief A mutator that works on the normal form.
+ *
+ * ExprMutator expects the input AST to be in the normal form, i.e., the expressions are
+ * normalized (no nesting and hence the AST is in ANF), and all checked_type_ and shape_ of
+ * expressions are available.
+ */
+class ExprMutator : public ExprMutatorBase {
+ public:
+  using ExprMutatorBase::VisitExpr_;
+
+  ExprMutator(Optional<IRModule> mod = NullOpt) { builder_ = BlockBuilder::Create(mod); }
+  Expr VisitExpr(const Expr& expr) override;
+  Expr VisitExpr_(const TupleNode* op) override;
+  Expr VisitExpr_(const VarNode* op) override;
+  Expr VisitExpr_(const DataflowVarNode* op) override;
+  Expr VisitExpr_(const FunctionNode* op) override;
+  Expr VisitExpr_(const SeqExprNode* op) override;
+  Expr VisitExpr_(const IfNode* op) override;
+
+  /*!
+   * \brief Generic dispatcher for bindings.
+   * \param binding The binding to be visited.
+   */
+  virtual void VisitBinding(const Binding& binding);
+  // specific leaf level visitor functions
+  virtual void VisitBinding_(const VarBindingNode* binding);
+  virtual void VisitBinding_(const MatchShapeNode* binding);
+
+  /*!
+   * \brief Generic dispatcher for binding blocks.
+   * \param block The binding block to be visited.
+   * \return The binding block after transformation.
+   */
+  virtual BindingBlock VisitBindingBlock(const BindingBlock& block) override;  // NOLINT(*)
+  // specific leaf level visitor functions
+  virtual BindingBlock VisitBindingBlock_(const BindingBlockNode* block);
+  virtual BindingBlock VisitBindingBlock_(const DataflowBlockNode* block);
+
+  /*!
+   * \brief Generic dispatcher for rewriting the var definition site.
+   * \param var The var to be visited.
+   * \return The var after post-order rewriting.
+   * \note VisitExpr_(const VarNode*) will only visit the usage site of a Var.
+   */
+  virtual Var VisitVarDef(const Var& var);
+  // specific leaf level visitor functions
+  virtual Var VisitVarDef_(const VarNode* var);
+  virtual Var VisitVarDef_(const DataflowVarNode* var);
+
+ protected:
+  class ExprNormalizer;
+
+  /*!
+   * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If.
+   * \param expr The expr to be visited.
+   * \return The expr after visiting.
+ */ + Expr VisitWithNewScope(const Expr& expr); + + /*! + * \brief Look up the value bound to a variable. + * \param var The var to be looked up. + * \return The value bound to the input \p var. + * \note For function parameters, this function returns NullOpt. + */ + Optional LookupBinding(const Var& var); + + /*! + * \brief Post-order rewrite a node and normalize. + * \param T The node type to be rewritten. + * \param op The node to be rewritten. + * \return The node after post rewritten. + */ + template + Expr VisitExprPostOrder_(const T* op) { + return builder_->Normalize(ExprMutator::VisitExpr_(op)); + } + + /*! + * \brief Create a new var with specified shape and type if the original var's shape or type does + * not match with the specified ones. + * \param var The var to be updated. + * \param shape The specified shape. + * \param type The specified type. + * \return The var filled with \p shape and \p type. + */ + Var WithShapeAndType(Var var, Optional shape, Type type); + + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; + + /*! \brief Remap a var to a new var in use-site. */ + std::unordered_map var_remap_; +}; + +/*! + * \brief The abstract interface of ExprVisitor. + */ +class PyExprVisitorNode : public Object, public ExprVisitor { + private: + using TSelf = PyExprVisitorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const RuntimeDepShapeNode* op)` function. */ + PackedFunc f_visit_runtime_dep_shape_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! 
\brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)` + * function. */ + PackedFunc f_visit_match_shape_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitType(const Type& t)` function. */ + PackedFunc f_visit_type{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + void VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + f_visit_expr(expr); + } else { + // Need to init the overwrite VTable + static FType vtable = InitVTable(); + vtable(expr, this); + } + } + + void VisitBinding(const Binding& binding) + PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); + + void VisitBinding_(const VarBindingNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, + ExprVisitor::VisitBinding_(binding)); + void VisitBinding_(const MatchShapeNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_shape_, + ExprVisitor::VisitBinding_(binding)); + + void VisitBindingBlock(const BindingBlock& block) + PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block)); + + void VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprVisitor::VisitBindingBlock_(block)); + void VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprVisitor::VisitBindingBlock_(block)); + + void VisitVarDef(const Var& var) + PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var)); + void VisitVarDef_(const VarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var)); + void VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprVisitor::VisitVarDef_(var)); + + void VisitType(const Type& t) PY_EXPR_VISITOR_DEFAULT(t, f_visit_type, ExprVisitor::VisitType(t)); + void VisitSpan(const Span& span) + PY_EXPR_VISITOR_DEFAULT(span, f_visit_span, ExprVisitor::VisitSpan(span)); + + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); + + private: + // initialize the vtable. 
+ static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_VISITOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_VISITOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_VISITOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_VISITOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_VISITOR_DISPATCH(RuntimeDepShapeNode, f_visit_runtime_dep_shape_); + PY_EXPR_VISITOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_VISITOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_VISITOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_VISITOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_VISITOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_VISITOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_VISITOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + return vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprVisitorNode); + +/*! + * \brief Managed reference to PyExprVisitorNode. + * \sa PyExprVisitorNode + */ +class PyExprVisitor : public ObjectRef { + public: + /*! + * \brief Create a PyExprVisitor with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_runtime_dep_shape_ The packed function of `VisitExpr_(const RuntimeDepShapeNode* + * op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_type The packed function of `VisitType(const Type& t)`. 
+ * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyVisitor created. + */ + TVM_DLL static PyExprVisitor MakePyExprVisitor( + PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, + PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, + PackedFunc f_visit_runtime_dep_shape_, PackedFunc f_visit_extern_func_, + PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, + PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, + PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->f_visit_expr = f_visit_expr; + n->f_visit_binding = f_visit_binding; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_type = f_visit_type; + n->f_visit_span = f_visit_span; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_runtime_dep_shape_ = f_visit_runtime_dep_shape_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_shape_ = f_visit_match_shape_; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + return PyExprVisitor(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprVisitor, ObjectRef, PyExprVisitorNode); +}; + +/*! + * \brief The abstract interface of ExprMutator. + */ +class PyExprMutatorNode : public Object, public ExprMutator { + private: + using TSelf = PyExprMutatorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const RuntimeDepShapeNode* op)` function. */ + PackedFunc f_visit_runtime_dep_shape_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. 
*/ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)` + * function. */ + PackedFunc f_visit_match_shape_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitType(const Type& t)` function. */ + PackedFunc f_visit_type{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. 
*/ + PackedFunc f_visit_span{nullptr}; + + Expr VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + return builder_->Normalize(f_visit_expr(expr)); + } else { + static FType vtable = InitVTable(); + return builder_->Normalize(vtable(expr, this)); + } + } + + void VisitBinding(const Binding& binding) { + if (f_visit_binding != nullptr) + f_visit_binding(binding); + else + ExprMutator::VisitBinding(binding); + } + + void VisitBinding_(const VarBindingNode* binding) { + if (f_visit_var_binding_ != nullptr) + f_visit_var_binding_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + void VisitBinding_(const MatchShapeNode* binding) { + if (f_visit_match_shape_ != nullptr) + f_visit_match_shape_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) + PY_EXPR_MUTATOR_DEFAULT(block, f_visit_binding_block, ExprMutator::VisitBindingBlock(block), + BindingBlock); + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + + Var VisitVarDef(const Var& var) + PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var); + Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_var_def_, + ExprMutator::VisitVarDef_(var), Var); + Var VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprMutator::VisitVarDef_(var), Var); + + Type VisitType(const Type& t) + PY_EXPR_MUTATOR_DEFAULT(t, f_visit_type, ExprMutator::VisitType(t), Type); + + /*! + * \brief Dispatcher for post-order rewrite. + * \param expr The Expr to be rewritten. + * \return The Expr after post-order rewritten. + */ + Expr VisitExprPostOrder(const Expr& expr) { + static FType post_order_vtable = InitPostOrderVTable(); + return post_order_vtable(expr, this); + } + + using ExprMutator::builder_; + using ExprMutator::LookupBinding; + using ExprMutator::var_remap_; + using ExprMutator::VisitWithNewScope; + using ExprMutator::WithShapeAndType; + + void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } + static constexpr const char* _type_key = "expr_functor.PyExprMutator"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); + + private: + // initialize the vtable. 
+ static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_MUTATOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_MUTATOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_MUTATOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_MUTATOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_MUTATOR_DISPATCH(RuntimeDepShapeNode, f_visit_runtime_dep_shape_); + PY_EXPR_MUTATOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_MUTATOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_MUTATOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_MUTATOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_MUTATOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_MUTATOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_MUTATOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + return vtable; + } + + // initialize the vtable for post order visit. + static FType InitPostOrderVTable() { + FType post_order_vtable; + // Set dispatch + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ConstantNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(VarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataflowVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ShapeExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(RuntimeDepShapeNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ExternFuncNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(GlobalVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(FunctionNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(CallNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(SeqExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(IfNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OpNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleGetItemNode); + return post_order_vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprMutatorNode); + +/*! + * \brief Managed reference to PyExprMutatorNode. + * \sa PyExprMutatorNode + */ +class PyExprMutator : public ObjectRef { + public: + /*! + * \brief Create a PyExprMutator with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_runtime_dep_shape_ The packed function of `VisitExpr_(const RuntimeDepShapeNode* + * op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. 
+ * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_type The packed function of `VisitType(const Type& t)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyExprMutator created. + */ + TVM_DLL static PyExprMutator MakePyExprMutator( + BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, + PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, + PackedFunc f_visit_shape_expr_, PackedFunc f_visit_runtime_dep_shape_, + PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, + PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, + PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->builder_ = builder_; + n->f_visit_expr = f_visit_expr; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_runtime_dep_shape_ = f_visit_runtime_dep_shape_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_binding = f_visit_binding; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_shape_ = f_visit_match_shape_; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + n->f_visit_type = f_visit_type; + n->f_visit_span = f_visit_span; + return PyExprMutator(n); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git 
a/include/tvm/relax/ir_functor.h b/include/tvm/relax/ir_functor.h new file mode 100644 index 0000000000..de4eca59c9 --- /dev/null +++ b/include/tvm/relax/ir_functor.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/ir_functor.h + * \brief A generic functor for working with Relax IR nodes. + * \sa tvm/relax/expr_functor.h for common IR rewriting use-cases. + */ +#ifndef TVM_RELAX_IR_FUNCTOR_H_ +#define TVM_RELAX_IR_FUNCTOR_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +template +class IRFunctor; + +#define IR_FUNCTOR_DEFAULT \ + { return VisitNodeDefault_(op, std::forward(args)...); } + +#define RELAX_IR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitNode_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class IRFunctor { + private: + using TSelf = IRFunctor; + using FType = NodeFunctor; + + public: + using result_type = R; + virtual ~IRFunctor() {} + + R operator()(const ObjectRef& n, Args... args) { + return VisitNode(n, std::forward(args)...); + } + + virtual R VisitNode(const ObjectRef& n, Args... args) { + ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " + "have generated invalid data."; + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + + // IR nodes inherited from Relay + virtual R VisitNode_(const relay::ConstantNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relay::TupleNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relay::GlobalVarNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relay::CallNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relay::IfNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const OpNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relay::TupleGetItemNode* op, Args... args) IR_FUNCTOR_DEFAULT; + + // IR nodes introduced by Relax + virtual R VisitNode_(const relax::VarNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::DataflowVarNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::ShapeExprNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::RuntimeDepShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::MatchShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::VarBindingNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::BindingBlockNode* op, Args... 
args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::DataflowBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::SeqExprNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::FunctionNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::ExternFuncNode* op, Args... args) IR_FUNCTOR_DEFAULT; + + virtual R VisitNodeDefault_(const Object* op, Args...) { + LOG(FATAL) << "no default visitor implemented for " << op->GetTypeKey(); + throw; + } + + private: + static FType InitVTable() { + FType vtable; + RELAX_IR_FUNCTOR_DISPATCH(relay::ConstantNode); + RELAX_IR_FUNCTOR_DISPATCH(relay::TupleNode); + RELAX_IR_FUNCTOR_DISPATCH(relay::GlobalVarNode); + RELAX_IR_FUNCTOR_DISPATCH(relay::CallNode); + RELAX_IR_FUNCTOR_DISPATCH(relay::IfNode); + RELAX_IR_FUNCTOR_DISPATCH(OpNode); + RELAX_IR_FUNCTOR_DISPATCH(relay::TupleGetItemNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::VarNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowVarNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::ShapeExprNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::RuntimeDepShapeNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::MatchShapeNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::VarBindingNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::BindingBlockNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowBlockNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::SeqExprNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::FunctionNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::ExternFuncNode); + return vtable; + } +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_FUNCTOR_H_ diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h new file mode 100644 index 0000000000..bd05e1814d --- /dev/null +++ b/include/tvm/relax/op_attr_types.h @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/op_attr_types.h + * \brief Data structures that can appear in operator attributes. + */ +#ifndef TVM_RELAX_OP_ATTR_TYPES_H_ +#define TVM_RELAX_OP_ATTR_TYPES_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +using relay::Call; + +/*! + * \brief Infer the output shape for operators. This function will + * be invoked to fill the \p shape_ field of expressions. + * \param call The call node. + * \param diag_ctx The diagnostic context for reporting errors. + * \return The inferred output shape expression. + */ +using FInferShape = + runtime::TypedPackedFunc(const Call& call, DiagnosticContext diag_ctx)>; + +/*! + * \brief Infer the output type for operators. This function will + * be invoked to fill the \p checked_type_ field of expressions. + * \param call The call node. + * \param diag_ctx The diagnostic context for reporting errors. 
+ * \return The inferred output type. + */ +using FInferType = runtime::TypedPackedFunc; + +/*! + * \brief Packed function implementation for operators. The relax operator will be lowered to + * this packed function call during codegen. + */ +using FCallPacked = String; + +/*! \brief Attributes used in unique operator */ +struct UniqueAttrs : public tvm::AttrsNode { + bool sorted; + bool return_inverse; + bool return_counts; + int dim; + TVM_DECLARE_ATTRS(UniqueAttrs, "relax.attrs.UniqueAttrs") { + TVM_ATTR_FIELD(sorted) + .describe( + "Whether to sort the unique elements in ascending order before returning as output.") + .set_default(true); + TVM_ATTR_FIELD(return_inverse) + .describe( + "Whether to return an additional tensor with indices for where elements in the " + "original input ended up in the returned unique list.") + .set_default(false); + TVM_ATTR_FIELD(return_counts) + .describe("Whether to return an additional tensor with counts of each unique elements") + .set_default(false); + TVM_ATTR_FIELD(dim) + .describe( + "The dimension to apply unique. If negative, the unique of the flattened input is " + "returned.") + .set_default(-1); + } +}; // struct UniqueAttrs + +struct PrintAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") { + TVM_ATTR_FIELD(format) + .describe("Python-style format string to use for displaying the input. Ignored if empty.") + .set_default(""); + } +}; + +struct AssertOpAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") { + TVM_ATTR_FIELD(format) + .describe( + "Python-style format string to use for displaying " + "an error message if the assert fails. " + "Ignored if empty.") + .set_default(""); + } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h new file mode 100644 index 0000000000..a080f52d5f --- /dev/null +++ b/include/tvm/relax/transform.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform.h + * \brief Relax specific transformation passes. + */ +#ifndef TVM_RELAX_TRANSFORM_H_ +#define TVM_RELAX_TRANSFORM_H_ + +#include +#include + +namespace tvm { +namespace relax { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; + +/*! + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. 
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent on.
+ *
+ * \return The created function pass.
+ */
+TVM_DLL Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<runtime::String> required, bool traceable = false);
+
+/*!
+ * \brief Create a dataflowblock pass.
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the dataflowblock pass.
+ * \param name The name of the dataflowblock pass.
+ * \param required The list of the passes that the dataflowblock pass is dependent on.
+ *
+ * \return The created dataflowblock pass.
+ */
+TVM_DLL Pass CreateDataflowBlockPass(
+    const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<runtime::String> required, bool traceable = false);
+
+/*!
+ * \brief Incorrectly transform the dataflow structure, for use as failure test cases.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FailTestRewrite();
+
+/*!
+ * \brief Perform fused multiply-add rewriting in dataflow blocks.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FMARewrite();
+
+/*!
+ * \brief Perform lambda lifting to lift nested functions to the global scope.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass LambdaLift();
+
+/*!
+ * \brief Transform all dataflow structures to the non-dataflow version.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass ToNonDataflow();
+
+/*!
+ * \brief Perform explicit tensor allocation for call_tir.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CallTIRRewrite();
+
+/*!
+ * \brief Simplify a Relax module by folding var bindings and match shape nodes.
+ * May include other forms of expression simplification in the future.
+ * Best used alongside constant folding and eliminating unused bindings.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CanonicalizeBindings();
+
+/*!
+ * \brief Transform Relax IR to normal form: transform the AST to A-normal form, and fill the
+ * checked_type_ and shape_ of expressions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass Normalize();
+
+/*!
+ * \brief Bind the params of a function in the module to constant tensors.
+ *
+ * \param func_name The name of the function whose parameters are bound.
+ * \param params The parameters to bind.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
+
+/*!
+ * \brief Fold constant expressions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FoldConstant();
+
+/*!
+ * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
+ * \note This is an auto-detect pass for "unscheduled prim_funcs"; the op_pattern will be
+ * "opaque" if we cannot detect it. Users can manually annotate the attr `op_pattern`
+ * on a prim_func.
+ * \return The Pass.
+ */
+TVM_DLL Pass AnnotateTIROpPattern();
+
+/*!
+ * \brief This pass groups bindings in a dataflow block of Relax functions and generates a new
+ * grouped Relax function for each group, according to the fusion algorithm described in the pass
+ * implementation. By grouping bindings into new Relax functions, we substitute the bindings in the
+ * function being manipulated with function calls to the new grouped functions.
+ *
+ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.
+ * \param fuse_opt_level The level of fuse optimization.
+ *        -1 indicates that the level will be inferred from the pass context.
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
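+
+// A sketch of composing the passes above with the standard pass infrastructure.
+// The particular pipeline below is illustrative, not a prescribed ordering:
+//
+//   IRModule Optimize(IRModule mod) {
+//     tvm::transform::Sequential seq({Normalize(), CanonicalizeBindings(), FoldConstant()});
+//     return seq(mod);
+//   }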
+
+/*!
+ * \brief Fuse relax sub-functions into larger TIR functions if possible.
+ * This pass works together with FuseOps to perform operator fusion.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FuseTIR();
+
+/*!
+ * \brief Remove unused global relax functions in an IRModule.
+ * \param entry_functions list of entry functions
+ * \return The Pass.
+ */
+TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
+
+/*!
+ * \brief Run codegen.
+ * \param target_codegens list of codegens
+ * \param entry_functions list of entry functions
+ * \return The Pass.
+ */
+TVM_DLL Pass RunCodegen(Optional<Array<runtime::String>> target_codegens,
+                        Array<runtime::String> entry_functions);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRANSFORM_H_
diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h
new file mode 100644
index 0000000000..d302b632b5
--- /dev/null
+++ b/include/tvm/relax/tuning_api.h
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/tuning_api.h
+ * \brief Relax Tuning Pass APIs.
+ */
+#ifndef TVM_RELAX_TUNING_API_H_
+#define TVM_RELAX_TUNING_API_H_
+#include
+#include
+#include
+
+#include
+namespace tvm {
+namespace relax {
+
+/*! \brief Helper function to unpack arguments in the array as parameters for the given packed
+ * function. */
+TVM_ALWAYS_INLINE TVMRetValue CallPackedWithArgsInArray(const runtime::PackedFunc f,
+                                                        const Array<ObjectRef>& args) {
+  size_t num_args = args.size();
+  std::vector<TVMValue> values(num_args);
+  std::vector<int> codes(num_args);
+  runtime::TVMArgsSetter setter(values.data(), codes.data());
+  const ObjectRef* ptr = args.template as<ArrayNode>()->begin();
+  for (size_t i = 0; i < num_args; ++i) {
+    setter(i, *(ptr + i));
+  }
+
+  TVMRetValue rv;
+  f.CallPacked(TVMArgs(values.data(), codes.data(), num_args), &rv);
+  return rv;
+}
+
+/*! \brief Choice manages a set of keys for transformation and constraint functions. */
+class ChoiceNode : public runtime::Object {
+ public:
+  /*! \brief ffi key for the transformation function. */
+  String transform_func_key;
+  /*! \brief ffi key for the constraint function. */
+  String constr_func_key;
+  Array<ObjectRef> transform_func_args;
+  Array<ObjectRef> constr_func_args;
+
+  /*! \brief The default destructor. */
+  virtual ~ChoiceNode() = default;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("transform_func_key", &transform_func_key);
+    v->Visit("transform_func_args", &transform_func_args);
+    v->Visit("constr_func_key", &constr_func_key);
+    v->Visit("constr_func_args", &constr_func_args);
+  }
+
+  /*! \brief Getter for constr_func.
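+   *
+   * \note The constraint function is looked up in the global PackedFunc registry by
+   * `constr_func_key`. A minimal sketch of registering one from C++ (the key
+   * "relax.tuning_api.my_constr" is illustrative; per CheckConstr below, the first
+   * argument passed is the IRModule under consideration):
+   * \code
+   * TVM_REGISTER_GLOBAL("relax.tuning_api.my_constr")
+   *     .set_body_typed([](IRModule mod) -> bool { return true; });
+   * \endcode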
*/ + const runtime::PackedFunc GetConstrFunc() { + const auto* constr_func = tvm::runtime::Registry::Get(constr_func_key); + ICHECK(constr_func != nullptr) << "constr_func_key is not registered: " << constr_func_key; + return *constr_func; + } + + /*! \brief Getter for transform_func. */ + const runtime::PackedFunc GetTransformFunc() { + auto* transform_func = tvm::runtime::Registry::Get(transform_func_key); + ICHECK(transform_func != nullptr) + << "transform_func_key is not registered: " << transform_func_key; + return *transform_func; + } + + /*! \brief Perform constr_func. */ + bool CheckConstr(const IRModule& mod) { + Array args(constr_func_args); + args.insert(args.begin(), mod); + return CallPackedWithArgsInArray(GetConstrFunc(), args); + } + + /*! \brief Perform transform_func. */ + IRModule ApplyTransformFunc(IRModule mod) { + // Apply transformation when constraint is satisfied. + if (CheckConstr(mod)) { + Array args(transform_func_args); + args.insert(args.begin(), GetRef(mod.CopyOnWrite())); + return CallPackedWithArgsInArray(GetTransformFunc(), args); + } + return mod; + } + + /*! + * \brief Serialize Choice as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Choice"; + TVM_DECLARE_BASE_OBJECT_INFO(ChoiceNode, Object); +}; + +/*! \brief Managed reference to ChoiceNode */ +class Choice : public runtime::ObjectRef { + public: + TVM_DLL explicit Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args); + /*! \brief Deserialize JSON-style object into Choice */ + TVM_DLL static Choice FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Choice, ObjectRef, ChoiceNode); +}; + +/*! \brief Knob manages a set of valid choices for an optimization. */ +class KnobNode : public runtime::Object { + public: + /*! \brief Name of the knob. */ + String name; + /*! \brief Decision space. */ + Map choices; + + /*! \brief The default destructor. */ + virtual ~KnobNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("choices", &choices); + } + + /*! \brief Check if a decision is valid. */ + bool IsValidDecision(String decision) { return choices.count(decision) > 0; } + + /*! \brief Apply decision if the constraint is satisfied. + Otherwise, return the original IRModule. + */ + IRModule Apply(IRModule mod, String decision) { + ICHECK(IsValidDecision(decision)) << "Invalid choice for this knob: " << decision; + return choices[decision]->ApplyTransformFunc(mod); + } + + /*! + * \brief Serialize Knob as a JSON-style object + * \return The JSON-style object + */ + ObjectRef AsJSON() const; + + static constexpr const char* _type_key = "relax.tuning_api.Knob"; + TVM_DECLARE_BASE_OBJECT_INFO(KnobNode, Object); +}; + +/*! \brief Managed reference to KnobNode */ +class Knob : public runtime::ObjectRef { + public: + TVM_DLL explicit Knob(String name, Map choices); + /*! \brief Deserialize JSON-style object into Knob */ + TVM_DLL static Knob FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Knob, ObjectRef, KnobNode); +}; + +/*! \brief Trace manages history of optimization decisions. */ +class TraceNode : public runtime::Object { + public: + /*! \brief Input IRModule. */ + IRModule in_mod; + /*! \brief Output IRModule. */ + mutable IRModule out_mod; + // TODO(sunggg): can we move knobs and decisions into private? + /*! 
\brief Knobs that are applied so far. */ + Array knobs; + /*! \brief Decisions made for the knobs. */ + Array decisions; + /*! \brief Performance of out_mod. */ + mutable double perf = -1; + /*! \brief Length of the decision history. */ + mutable int size = 0; + /*! \brief The default destructor. */ + virtual ~TraceNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("in_mod", &in_mod); + v->Visit("out_mod", &out_mod); + v->Visit("knobs", &knobs); + v->Visit("decisions", &decisions); + v->Visit("perf", &perf); + v->Visit("size", &size); + } + + /*! \brief Verify current decision history. */ + bool Verify() const { + if (knobs.size() != decisions.size()) return false; + int n = knobs.size(); + for (int i = 0; i < n; i++) { + if (!knobs[i]->IsValidDecision(decisions[i])) return false; + } + return true; + } + + /*! \brief Add a knob and its decision to the current trace. */ + IRModule Add(Knob knob, String decision) { + out_mod = knob->Apply(out_mod, decision); + knobs.push_back(knob); + decisions.push_back(decision); + // perf number should be initialized after new decision is applied. + perf = -1; + // increment history size. + size++; + return out_mod; + } + + /*! + * \brief Serialize Trace as a JSON-style object + * \param include_in_mod Boolean config to include input IRModule in the output. + * \return The JSON-style object + */ + ObjectRef AsJSON(bool include_in_mod = true) const; + + /*! \brief Set the performance. */ + void SetPerf(double _perf) { perf = _perf; } + /*! \brief Set output module. */ + void SetOutMod(IRModule mod_) { out_mod = mod_; } + + static constexpr const char* _type_key = "relax.tuning_api.Trace"; + TVM_DECLARE_BASE_OBJECT_INFO(TraceNode, Object); +}; + +/*! \brief Managed reference to TraceNode */ +class Trace : public runtime::ObjectRef { + public: + /*! \brief Default constructor. Creating an empty trace. */ + Trace(); + /*! + * \brief Constructor. Creating a trace from existing knobs and their decisions + * \param in_mod Input IRModule + * \param knobs The knobs used + * \param decisions The decisions made in sampling + */ + TVM_DLL explicit Trace(IRModule in_mod, Array knobs, Array decisions); + /*! \brief Deserialize JSON-style object into Trace */ + TVM_DLL static Trace FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, ObjectRef, TraceNode); +}; + +/*! \brief The class of tuning records. */ +class TuningRecordNode : public runtime::Object { + public: + /*! \brief The trace tuned. */ + Trace trace; + /*! \brief The measurement record in seconds. */ + Optional> run_secs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("trace", &trace); + v->Visit("run_secs", &run_secs); + } + + static constexpr const char* _type_key = "relax.tuning_api.TuningRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + + /*! + * \brief Export the tuning record to a JSON string. + * \param include_irmod Boolean config to include IRModules in the output. + * \return JSON object + */ + ObjectRef AsJSON(bool include_irmod = false) const; +}; + +/*! + * \brief The managed reference of TuningRecordNode. + * \sa TuningRecordNode + */ +class TuningRecord : public runtime::ObjectRef { + public: + /*! + \brief Constructor of a tuning record. + \param trace The trace of the tuning record. + \param run_secs The running time of the tuning record. + */ + TVM_DLL explicit TuningRecord(Trace trace, Optional> run_secs); + /*! + * \brief Create a tuning record from a json object. 
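+   * \note A sketch of the intended round trip with AsJSON:
+   * \code
+   * TuningRecord record(trace, run_secs);
+   * ObjectRef json = record.AsJSON();
+   * TuningRecord restored = TuningRecord::FromJSON(json);
+   * \endcode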
+   * \param json_obj The json object.
+   * \return The tuning record created.
+   */
+  TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode);
+};
+
+/*! \brief The abstract interface of database. */
+class DatabaseNode : public runtime::Object {
+ public:
+  /*! \brief Default destructor */
+  virtual ~DatabaseNode() = default;
+  /*!
+   * \brief Check if the database has the given workload.
+   * \param mod The IRModule to be searched for.
+   * \return Whether the database has the given workload.
+   */
+  virtual bool HasWorkload(const IRModule& mod) = 0;
+  /*!
+   * \brief Check if the database has a measurement record for the given workload and target pair.
+   * \param workload The workload to be searched for.
+   * \param target The target to be searched for.
+   * \return Whether the database has a measurement record for the given workload and target pair.
+   */
+  virtual bool HasMeasurementRecord(const meta_schedule::Workload& workload,
+                                    const Target& target) = 0;
+  /*!
+   * \brief Check if the database has a tuning record for the given workload and target pair.
+   * \param workload The workload to be searched for.
+   * \param target The target to be searched for.
+   * \return Whether the database has a tuning record for the given workload and target pair.
+   */
+  virtual bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) = 0;
+  /*!
+   * \brief Look up or add the workload to the database if missing.
+   * \param mod The IRModule to be searched for or added.
+   * \return The workload corresponding to the given IRModule.
+   */
+  virtual meta_schedule::Workload CommitWorkload(const IRModule& mod) = 0;
+  /*!
+   * \brief Add a measurement record for a given pair of target and workload to the database.
+   * \param workload Workload to be searched for.
+   * \param target Target to be searched for.
+   * \param record Measurement record to be added.
+   */
+  virtual void CommitMeasurementRecord(const meta_schedule::Workload& workload,
+                                       const Target& target, const Array<FloatImm>& record) = 0;
+  /*!
+   * \brief Add a tuning record for a given pair of target and workload to the database.
+   * \param workload Workload to be searched for.
+   * \param target Target to be searched for.
+   * \param record Tuning record to be added.
+   */
+  virtual void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target,
+                                  const TuningRecord& record) = 0;
+  /*!
+   * \brief Get the top K tuning records of the given workload and target from the database.
+   * \param workload The workload to be searched for.
+   * \param target Target to be searched for.
+   * \param top_k The number of top records to be returned.
+   * \return An array of top K tuning records for the given workload.
+   */
+  virtual Array<TuningRecord> GetTopK(const meta_schedule::Workload& workload,
+                                      const Target& target, int top_k) = 0;
+  /*!
+   * \brief Get the measurement record of the given workload and target from the database.
+   * \param workload The workload to be searched for.
+   * \param target Target to be searched for.
+   * \return The measurement record.
+   */
+  virtual Array<FloatImm> GetMeasurementRecord(const meta_schedule::Workload& workload,
+                                               const Target target) = 0;
+
+  static constexpr const char* _type_key = "relax.tuning_api.Database";
+  TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
+};
+
+/*!
+ * \brief Managed reference to DatabaseNode.
+ * \sa DatabaseNode
+ */
+class Database : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a default database that uses JSON files for tuning records.
+   * \param path_workload The path to the workload table.
+   * \param path_tuning_record The path to the tuning record table.
+   * \param path_measurement_record The path to the measurement record table.
+   * \param allow_missing Whether to create a new file when the given path is not found.
+   */
+  TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
+                                       String path_measurement_record, bool allow_missing);
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
+};
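+
+// A usage sketch (the file names are illustrative): create a JSON-backed
+// database, then commit a workload before storing records against it.
+//
+//   Database db = Database::JSONDatabase("workload.json", "tuning_record.json",
+//                                        "measurement_record.json", /*allow_missing=*/true);
+//   meta_schedule::Workload workload = db->CommitWorkload(mod);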
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_TUNING_API_H_
diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h
new file mode 100644
index 0000000000..ce6621f7aa
--- /dev/null
+++ b/include/tvm/relax/type.h
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/type.h
+ * \brief Relax typed AST nodes.
+ */
+#ifndef TVM_RELAX_TYPE_H_
+#define TVM_RELAX_TYPE_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace relax {
+
+class ShapeTypeNode : public TypeNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
+
+  bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { return true; }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
+
+  static constexpr const char* _type_key = "relax.ShapeType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode);
+};
+
+class ShapeType : public Type {
+ public:
+  TVM_DLL ShapeType(Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode);
+};
+
+class ObjectTypeNode : public TypeNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
+
+  bool SEqualReduce(const ObjectTypeNode* other, SEqualReducer equal) const { return true; }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
+
+  static constexpr const char* _type_key = "relax.ObjectType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ObjectTypeNode, TypeNode);
+};
+
+class ObjectType : public Type {
+ public:
+  TVM_DLL ObjectType(Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectType, Type, ObjectTypeNode);
+};
+
+class DynTensorTypeNode : public BaseTensorTypeNode {
+ public:
+  /*!
+   * \brief The number of dimensions of the tensor; use -1 to denote a tensor with an unknown
+   * number of dimensions.
+   */
+  int ndim;
+  /*! \brief The content data type; use void to denote that the dtype is unknown. */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("ndim", &ndim);
+    v->Visit("dtype", &dtype);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const {
+    return equal(ndim, other->ndim) && equal(dtype, other->dtype);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(ndim);
+    hash_reduce(dtype);
+  }
+
+  inline bool IsUnknownNdim() const { return ndim == -1; }
+
+  inline bool IsUnknownDtype() const { return dtype.is_void(); }
+
+  static constexpr const char* _type_key = "relax.DynTensorType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode);
+};
+
+/*!
+ * \brief Managed reference to DynTensorTypeNode.
+ * \sa DynTensorTypeNode.
+ */
+class DynTensorType : public Type {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param ndim The number of dimensions of the tensor.
+   * \param dtype The runtime dtype of the tensor's elements.
+   */
+  TVM_DLL DynTensorType(int ndim, DataType dtype, Span span = Span());
+
+  /*!
+   * \brief Create a DynTensorType with unknown ndim.
+   */
+  TVM_DLL static DynTensorType CreateUnknownNDim(DataType dtype, Span span = Span());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode);
+};
+
+class DimTypeNode : public TypeNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
+
+  bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const { return true; }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
+
+  static constexpr const char* _type_key = "relax.DimType";
+  TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode);
+};
+
+class DimType : public Type {
+ public:
+  TVM_DLL DimType(Span span = Span());
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode);
+};
+
+/*!
+ * \brief Check the subtype relationship between base and derived.
+ * \param base The base type.
+ * \param derived The derived type.
+ * \return True if \p derived is a subtype of \p base or if both are the same type;
+ * false otherwise.
+ */
+bool IsBaseOf(const Type& base, const Type& derived);
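+
+// A sketch of the expected subtyping behavior: a tensor type with unknown
+// ndim/dtype acts as a base of any concrete tensor type.
+//
+//   DynTensorType base(/*ndim=*/-1, DataType::Void());  // any tensor
+//   DynTensorType derived(/*ndim=*/2, DataType::Float(32));
+//   ICHECK(IsBaseOf(base, derived));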
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_TYPE_H_
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
new file mode 100644
index 0000000000..6bfd0d0daa
--- /dev/null
+++ b/include/tvm/relax/utils.h
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/utils.h
+ * \brief Utility classes and functions for working with the Relax IR.
+ */
+#ifndef TVM_RELAX_UTILS_H_
+#define TVM_RELAX_UTILS_H_
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Utility data structure for generating unique names for IR construction.
+ */
+class NameTable {
+ public:
+  /*!
+   * \brief Generate a unique name with a specified prefix.
+   * \param prefix The name prefix.
+   * \return The generated name.
+   */
+  inline std::string GetUniqueName(std::string prefix) {
+    std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = alloc_map_.find(prefix);
+    if (it != alloc_map_.end()) {
+      while (alloc_map_.count(unique_prefix = prefix + std::to_string(++it->second)) > 0) {
+      }
+    }
+    alloc_map_[unique_prefix] = 0;
+    return unique_prefix;
+  }
+
+  NameTable() = default;
+
+  template <typename Iter, typename Lambda>
+  explicit NameTable(Iter begin, Iter end, Lambda f) {
+    // static_assert is more reader-friendly than SFINAE when template specialization is not
+    // needed.
+    static_assert(std::is_convertible<decltype(f(*begin)), std::string>::value,
+                  "Lambda f must have a signature of [?](*it) -> string {}");
+    for (auto it = begin; it != end; ++it) {
+      const std::string& name = f(*it);
+      const size_t idx_last_first_num = std::distance(
+          std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }),
+          name.rend());
+      // name = {O = others}{D = consecutive digits}
+      // let O -> prefix;
+      std::string prefix = name.substr(0, idx_last_first_num);
+      ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name;
+      if (0 == alloc_map_.count(prefix)) alloc_map_[prefix] = 0;
+      if (idx_last_first_num < name.size()) {  // has some digits.
+        // let D's nearest natural number -> idx;
+        // note: std::stoi("000123") == 123, so leading zeros are dropped;
+        alloc_map_[prefix] =
+            std::max(alloc_map_[prefix], std::stoi(name.substr(idx_last_first_num)));
+      }
+    }
+  }
+
+  template <typename Iter>
+  explicit NameTable(Iter begin, Iter end)
+      : NameTable(begin, end, [](const decltype(*begin)& v) { return v; }) {}
+
+ private:
+  std::unordered_map<std::string, int> alloc_map_;
+};
+
+/*!
+ * \brief Bind the variables to a Relax expression. This is a helper
+ * function usually called by other pass functions to help optimizations.
+ * If any free variables are introduced into a function, those are added
+ * to the function parameters.
+ * Additionally this may change the order of parameters if you map a variable
+ * to a variable.
+ *
+ * \param expr The input expression.
+ * \param binds The variable-to-expression map that will be used to help the
+ * binding.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
+
+/*!
+ * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
+ *
+ * \param ty The input type.
+ * \param permit_unknown_rank If true, it will permit the input type to have unknown rank
+ * (ndim of -1), which will require a dynamic check.
+ * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
+ * (namely, void), which will require a dynamic check.
+ * + * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown + * rank or dtype) + */ +TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, + bool permit_unknown_dtype = true); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_UTILS_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index bd094a7f69..c2b1becbcb 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -85,6 +85,7 @@ class ConstantNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("shape_", &shape_); } bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { @@ -131,6 +132,7 @@ class TupleNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("shape_", &shape_); } bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const { @@ -329,6 +331,7 @@ class CallNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("shape_", &shape_); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index cdea8e8e3c..f13322eed3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -60,7 +60,7 @@ using Sequential = tvm::transform::Sequential; */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required); + int opt_level, String name, tvm::Array required, bool traceable = false); /*! \brief Remove let-bound expressions which do not effect the program result. * diff --git a/include/tvm/runtime/relax_vm/bytecode.h b/include/tvm/runtime/relax_vm/bytecode.h new file mode 100644 index 0000000000..50735f6fb2 --- /dev/null +++ b/include/tvm/runtime/relax_vm/bytecode.h @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/bytecode.h + * \brief The bytecode for the virtual machine. + */ +#ifndef TVM_RUNTIME_RELAX_VM_BYTECODE_H_ +#define TVM_RUNTIME_RELAX_VM_BYTECODE_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief The storage type for the bytecode in the VM. + */ +using ExecWord = int64_t; + +/*! \brief A register name. */ +using RegName = ExecWord; + +/*! + * \brief An alias for the integer type used ubiquitously in the VM. + */ +using Index = ExecWord; + +/*! + * \brief An enumeration of Relax's opcodes. 
+ * + * The opcode is used to implement instruction + * as a tagged union. + */ +enum class Opcode { + Call = 1U, + Ret = 2U, + Goto = 3U, + If = 4U, +}; + +/*! \brief A single virtual machine instruction. + * + * The representation of the instruction is as + * a tagged union. + * + * The first field represents which instruction, + * and by extension which field of the union + * is active. + */ +struct Instruction { + /*! \brief Random magic number that represents void argument. */ + static constexpr RegName kVoidArg = 0x00EC66FE0321975A; + /*! \brief Random magic number that represents the VM. */ + static constexpr RegName kVMRegister = 0x008D14FA4379015C; + /*! + * \brief The kind of instruction's argument. + */ + enum ArgKind { + kRegister = 0, + kImmediate = 1, + kConstIdx = 2, + }; + /*! + * \brief The auxiliary data structure for instruction argument. + */ + struct Arg { + /*! \brief The number of bit for storing value. */ + static constexpr ExecWord kValueBit = sizeof(ExecWord) * 8 - 8; + /*! \brief The bit mask of the value part. */ + static constexpr ExecWord kValueMask = (static_cast(1) << kValueBit) - 1; + /*! \brief Construct a void argument. */ + Arg() : data(Instruction::kVoidArg) {} + /*! \brief Construct from the data. */ + explicit Arg(ExecWord data) : data(data) {} + /*! \brief Construct from the kind and value. */ + Arg(ArgKind kind, Index value) { + // TODO(ziheng): check value? + this->data = (static_cast(kind) << kValueBit) | (value & kValueMask); + } + /*! + * \brief Get the kind of argument.. + * \return The kind of argument. + */ + ArgKind kind() const { + uint8_t kind = (data >> kValueBit) & 0xFF; + return Instruction::ArgKind(kind); + } + /*! + * \brief Get the value of argument.. + * \return The value of argument. + */ + ExecWord value() const { return data & ((static_cast(1) << kValueBit) - 1); } + /*! \brief The underlying stored data. */ + ExecWord data; + }; + /*! \brief The instruction opcode. */ + Opcode op; + /*! \brief The destination register. */ + RegName dst; + union { + struct /* Call */ { + /*! \brief The index into the packed function table. */ + Index func_idx; + /*! \brief The number of arguments to the packed function. */ + Index num_args; + /*! \brief The arguments of the packed function. */ + Arg* args; + }; + struct /* Ret */ { + /*! \brief The return result. */ + RegName result; + }; + struct /* Goto */ { + /*! \brief The jump offset. */ + Index pc_offset; + }; + struct /* If */ { + /*! \brief The register containing the cond value. */ + RegName cond; + /*! \brief The program counter offset for the false branch. */ + Index false_offset; + }; + }; + /*! + * \brief Construct a Call instruction. + * \param func_idx The index of the function to call. + * \param num_args The number of arguments. + * \param args The input arguments. + * \param dst The destination register. + * \return The call instruction. + */ + static Instruction Call(Index func_idx, Index num_args, Arg* args, RegName dst); + /*! + * \brief Construct a return instruction. + * \param result The register containing the return value. + * \return The return instruction. + */ + static Instruction Ret(RegName result); + /*! + * \brief Construct a goto instruction. + * \param pc_offset The register containing the jump offset. + * \return The goto instruction. + */ + static Instruction Goto(RegName pc_offset); + /*! + * \brief Construct an If instruction. + * \param cond The register containing the cond value. + * \param false_offset The program counter offset for the false branch. 
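+   * \note Given the fields above, the expected semantics are that when \p cond holds,
+   * execution falls through to the next instruction, and otherwise the program counter
+   * advances by \p false_offset (a reading inferred from the instruction layout, not a
+   * normative statement).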
+ * \return The If instruction. + */ + static Instruction If(RegName cond, Index false_offset); +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_BYTECODE_H_ diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h new file mode 100644 index 0000000000..9856598118 --- /dev/null +++ b/include/tvm/runtime/relax_vm/executable.h @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/executable.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ +#define TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ + +#include +#include +#include + +#include +#include +#include + +#include "./bytecode.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! + * \brief An object representing a vm closure. + */ +class VMClosureObj : public ClosureObj { + public: + /*! + * \brief The function name. The function could be any + * function object that is compatible to the VM runtime. + */ + String func_name; + /*! \brief The free variables of the closure. */ + Array free_vars; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.vm.Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj); +}; + +/*! \brief reference to closure. */ +class VMClosure : public Closure { + public: + VMClosure(String func_name, Array free_vars); + TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj); +}; + +/*! + * \brief A representation of a Relax function in the VM. + * + * Contains metadata about the compiled function, as + * well as the compiled VM instructions. + */ +struct VMFunction { + /*! \brief The function's name. */ + std::string name; + /*! \brief The start instruction index of the function. */ + Index start_instr; + /*! \brief The number of arguments of the function. */ + Index num_args; + /*! \brief The register file size of the function. */ + Index register_file_size; + /*! \brief The function parameter names.*/ + std::vector param_names; +}; + +/*! + * \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + */ +class Executable : public runtime::ModuleNode { + public: + /*! + * \brief Get a PackedFunc from the executable module. + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + /*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globals and constants, etc. 
+ * \return The statistics represented by a string. + */ + std::string Stats() const; + /*! + * \brief Get the i-th instruction from the executable. + * \param i The index of the instruction to be fetched. + * \return The instruction. + */ + Instruction GetInstruction(Index i) const; + /*! + * \brief Set j-th byte data of i-th instruction to val. + * \param i The index of the instruction to be updated. + * \param j The index of the byte data of the instruction to be updated. + * \param val The value to be set + */ + void SetInstructionData(Index i, Index j, ExecWord val); + /*! + * \brief Print the instructions as text format. + * \return The text format of the instructions. + */ + String AsText() const; + /*! + * \brief Print the instructions as python program. + * \return The python program of the instructions, represented by a string. + */ + String AsPython() const; + /*! + * \brief Write the Executable to the binary stream in serialized form. + * \param stream The binary stream to save the executable to. + */ + void SaveToBinary(dmlc::Stream* stream) final; + /*! + * \brief Load Executable from the binary stream in serialized form. + * \param stream The binary stream that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. + */ + static Module LoadFromBinary(void* stream); + /*! + * \brief Write the Executable to the provided path as a file containing its serialized content. + * \param file_name The name of the file to write the serialized data to. + * \param format The target format of the saved file. + */ + void SaveToFile(const std::string& file_name, const std::string& format) final; + /*! + * \brief Load Executable from the file. + * \param file_name The path of the file that load the executable from. + * \return The loaded executable, in the form of a `runtime::Module`. + */ + static Module LoadFromFile(const std::string& file_name); + + /*! \brief The virtual machine's function table. */ + std::vector global_funcs; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map global_map; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief The name of packed functions. */ + std::vector func_names; + /*! + * \brief A mapping from the packed function (as string) to the index that + * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. + */ + std::unordered_map func2idx; + /*! \brief The offset of instruction. */ + std::vector instr_offset; + /*! \brief The byte data of instruction. */ + std::vector instr_data; + + virtual ~Executable() {} + + const char* type_key() const final { return "relax.Executable"; } + + private: + /*! + * \brief Save the globals. + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); + /*! + * \brief Save the constant pool. + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); + /*! + * \brief Save the instructions. + * \param strm The input stream. + */ + void SaveCodeSection(dmlc::Stream* strm); + /*! + * \brief Save the packed functions. + * \param strm The input stream. + */ + void SavePackedFuncNames(dmlc::Stream* strm); + /*! + * \brief Load the globals. + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); + /*! + * \brief Load the constant pool. + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); + /*! + * \brief Load the instructions. + * \param strm The input stream. 
+   */
+  void LoadCodeSection(dmlc::Stream* strm);
+  /*!
+   * \brief Load the packed functions.
+   * \param strm The input stream.
+   */
+  void LoadPackedFuncNames(dmlc::Stream* strm);
+};
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_
diff --git a/include/tvm/runtime/relax_vm/memory_manager.h b/include/tvm/runtime/relax_vm/memory_manager.h
new file mode 100644
index 0000000000..e5ae8cfcfb
--- /dev/null
+++ b/include/tvm/runtime/relax_vm/memory_manager.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/relax_vm/memory_manager.h
+ * \brief Abstract device memory management API
+ */
+#ifndef TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_
+#define TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+struct Buffer {
+  /*! \brief The pointer to the allocated block of memory. */
+  void* data{nullptr};
+  /*! \brief The size of the block. */
+  size_t size{0};
+  /*! \brief The device of the allocated buffer. */
+  Device device;
+};
+
+enum AllocatorType {
+  kNaive = 1,
+  kPooled,
+};
+
+class Allocator {
+ public:
+  explicit Allocator(AllocatorType type) : type_(type) {}
+  virtual ~Allocator() = default;
+  /*! \brief Allocate an empty NDArray using the allocator.
+   * \param shape The shape of the NDArray.
+   * \param dtype The datatype of the NDArray.
+   * \param dev The device where the array is allocated.
+   * \return The empty NDArray.
+   */
+  runtime::NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev);
+  /*! \brief Return the allocator type. */
+  inline AllocatorType type() const { return type_; }
+  /*! \brief Allocate a buffer given a size, alignment and type.
+   * \param nbytes The size of the buffer.
+   * \param alignment The alignment of the buffer.
+   * \param type_hint A type hint to the allocator.
+   * \return A sized allocation in the form of a buffer.
+   */
+  virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0;
+  /*! \brief Free a buffer allocated by the allocator.
+   * \param buffer The buffer to free.
+   */
+  virtual void Free(const Buffer& buffer) = 0;
+
+ private:
+  AllocatorType type_;
+};
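+
+// A minimal sketch of a concrete allocator built on DeviceAPI (error handling
+// elided; the class name is illustrative):
+//
+//   class NaiveAllocatorSketch final : public Allocator {
+//    public:
+//     explicit NaiveAllocatorSketch(Device dev) : Allocator(kNaive), device_(dev) {}
+//     Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) final {
+//       Buffer buf;
+//       buf.device = device_;
+//       buf.size = nbytes;
+//       buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint);
+//       return buf;
+//     }
+//     void Free(const Buffer& buffer) final {
+//       DeviceAPI::Get(buffer.device)->FreeDataSpace(buffer.device, buffer.data);
+//     }
+//
+//    private:
+//     Device device_;
+//   };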
+
+class MemoryManager {
+ public:
+  static MemoryManager* Global();
+  /*!
+   * \brief Get or create an allocator given the device and allocator type.
+   * \param dev The TVM device
+   * \param type The allocator type
+   * \return The memory allocator.
+   */
+  static Allocator* GetOrCreateAllocator(Device dev, AllocatorType type);
+  /*!
+   * \brief Get an allocator given the device.
+   * \param dev The TVM device
+   * \return The memory allocator.
+   */
+  static Allocator* GetAllocator(Device dev);
+
+ private:
+  MemoryManager() {}
+
+ private:
+  std::mutex mutex_;
+  std::unordered_map<Device, std::unique_ptr<Allocator>> allocators_;
+};
+
+/*! \brief An object representing a storage allocation. */
+class StorageObj : public Object {
+ public:
+  /*! \brief The underlying buffer of the storage allocation. */
+  Buffer buffer;
+
+  /*! \brief Allocate an NDArray from a given piece of storage. */
+  runtime::NDArray AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype);
+
+  /*! \brief The deleter for an NDArray when allocated from underlying storage. */
+  static void Deleter(Object* ptr);
+
+  ~StorageObj() {
+    auto alloc = MemoryManager::Global()->GetAllocator(buffer.device);
+    alloc->Free(buffer);
+  }
+
+  static constexpr const uint32_t _type_index = runtime::TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "relax.Storage";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object);
+};
+
+/*! \brief Reference to storage. */
+class Storage : public ObjectRef {
+ public:
+  explicit Storage(Buffer buffer);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj);
+};
+
+}  // namespace relax_vm
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_RELAX_VM_MEMORY_MANAGER_H_
diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h
new file mode 100644
index 0000000000..a6de95cd45
--- /dev/null
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/relax_vm/vm.h
+ */
+#ifndef TVM_RUNTIME_RELAX_VM_VM_H_
+#define TVM_RUNTIME_RELAX_VM_VM_H_
+
+#include
+#include
+#include
+
+#include "./bytecode.h"
+#include "./executable.h"
+#include "./memory_manager.h"
+
+namespace tvm {
+namespace runtime {
+namespace relax_vm {
+
+/*!
+ * \brief The register type.
+ */
+using RegType = TVMRetValue;
+
+/*!
+ * \brief A representation of a stack frame.
+ *
+ * A stack frame is a record containing the information needed
+ * to restore the caller's virtual machine state after returning
+ * from a function call.
+ */
+struct VMFrame {
+  /*! \brief The return program counter. */
+  Index return_pc;
+  /*! \brief Statically allocated space for objects. */
+  std::vector<RegType> register_file;
+  /*! \brief Register in the caller's frame to put the return value. */
+  RegName caller_return_register;
+  // The following fields are used for PackedFunc calls within
+  // a single function scope. The space is reused across multiple
+  // packed func calls to increase cache locality and avoid re-allocation.
+  /*! \brief Temporary argument value stack for packed func call. */
+  std::vector<TVMValue> call_arg_values;
+  /*! \brief Temporary argument tcode stack for packed func call.
*/ + std::vector call_arg_tcodes; + + VMFrame(Index pc, Index register_file_size) + : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} +}; + +/*! + * \brief The virtual machine. + * + * The virtual machine contains all the current execution state, + * as well as the executable. + * + * The goal is to have a single self-contained object, + * enabling one to easily pass around VMs, execute them on + * multiple threads, or serialize them to disk or over the + * wire. + */ +class VirtualMachine : public runtime::ModuleNode { + public: + /*! + * \brief Initialize the virtual machine for a set of devices. + * \param devices The set of TVM devices. + * \param alloc_types The allocator types for each device. + */ + void Init(const std::vector& devices, const std::vector& alloc_types); + /*! + * \brief Load the executable for the virtual machine. + * \param exec The executable. + */ + void LoadExecutable(ObjectPtr exec); + /*! + * \brief Get a PackedFunc from module. + * + * The PackedFunc may not be fully initialized, + * there might still be first time running overhead when + * executing the function on certain devices. + * For benchmarking, use prepare to eliminate + * + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * + * \return PackedFunc(nullptr) when it is not available. + * + * \note The function will always remain valid. + * If the function needs resource from the module(e.g. late linking), + * it should capture sptr_to_self. + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + ~VirtualMachine() {} + + const char* type_key() const final { return "relax.VirtualMachine"; } + + /*! \brief The kernel library. */ + Optional lib; + /*! \brief The memory allocators. */ + std::vector allocators; + /*! \brief Runtime physical device list. */ + std::vector devices; + + protected: + /*! + * \brief Push a call frame onto the call stack. + * \param ret_pc The program counter to return to. + * \param vm_func The function to be pushed to the call stack. + */ + void PushFrame(Index ret_pc, const VMFunction& vm_func); + /*! + * \brief Pop a frame off the call stack. + */ + void PopFrame(); + /*! + * \brief Write to a VM register. + * \param frame current vm frame. + * \param reg The register to write to. + * \param obj The object to write to. + */ + inline void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj); + /*! + * \brief Read a VM register. + * \param frame current vm frame. + * \param reg The register to read from. + * \return The value of the register. + */ + inline RegType ReadRegister(VMFrame* frame, RegName reg) const; + /*! + * \brief Prepare function table so that func_table_[func_index] is populated. + * \param func_index The function index. + */ + inline void PrepareFuncTable(Index func_index); + /*! + * \brief Invoke a VM function. + * \param fidx The function index. + * \param args The arguments to the function. + * \return The object representing the result. + */ + RegType Invoke(Index fidx, const std::vector& args); + /*! + * \brief Read a VM register and cast it to int64_t. + * \param reg The register to read from. + * \return The read scalar. + */ + int64_t LoadScalarInt(RegName reg) const; + /*! \brief Run VM dispatch loop. */ + void RunLoop(); + /*! + * \brief Run call instruction. + * \param curr_frame The current frame. + * \param inst The call instruction. 
+ */ + inline void RunInstrCall(VMFrame* curr_frame, Instruction inst); + + /*! + * \brief Set inputs to a function. + * \param func_name The function name. + * \param args args[offset:] are arguments to the function. If the arguments are not of the + * correct device for the function, they will be copied to the device. + * \param offset Starting offset of the arguments in \p args. + * \note This interface works when using VM over RPC by internally converting NDArray in + * the arguments to DLTensor, which is supported in RPC where remote could only have a minimal C + * runtime. + */ + void SetInput(std::string func_name, TVMArgs args, int offset); + + /*! + * \brief Set a function argument with a given index to an input tensor. + * \param func_args the function arguments. + * \param inp_tensor some input tensor (not necessarily DLTensor). When it's an NDArray or a list + * of NDArray, they will be converted. + * \param index The input tensor index in the function arguments. + * \param dev device to copy to if needed. + */ + void SetInputTensorWithIndex(std::vector& func_args, const TVMArgValue& inp_tensor, + int index, Device dev); + + /*! + * \brief Look up whether the VM has a function by the given name. + * \param func_name the function's name + * \return The function, if it exists. Logs a fatal error if not. + */ + VMFunction LookupVMFunction(const std::string& func_name); + + /*! + * \brief Look up whether the VM has outputs for the given function. + * \param func_name the function's name + * \return The output, if it exists. Logs a fatal error if not. + */ + RegType LookupVMOutput(const std::string& func_name); + + private: + /*! \brief The loaded executable. */ + ObjectPtr exec_; + /*! + * \brief Internal function table cache to speedup execution. + * \note This is used to cache functions so we do not need + * to look up by name every time. + * It does mean that the definition of the function + * cannot change when the vm get loaded. + */ + std::vector func_table_; + /*! + * \brief The current stack of call frames. + * \note: Use unique ptr to avoid re-allocation and copy when frames_ get resized. + */ + std::vector> frames_; + /*! \brief The virtual machine PC. */ + Index pc_{0}; + /*! \brief The special return register. */ + RegType return_value_; + /*! \brief The global constant pool */ + std::vector constants; + /*! \brief The function name to input register mapping. */ + std::unordered_map> inputs_; + /*! \brief The function name to output register. */ + std::unordered_map outputs_; + /*! \brief A store of closures created by `save_function`. */ + std::unordered_map saved_closures_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_VM_H_ diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccff..dacfc361a6 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. 
+ */ + Map<GlobalVar, BaseFunc> functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c..10996a7b10 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,21 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a function without giving the specific function implementation. + * \note It is usually used for cross-function calls; the implementation can be supplied later via `DefFunction`. + * \param func_name The function's unique name. + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name); + +/*! + * \brief Define a function that was declared before. + * \param func_name The function's unique name. + * \param func The given function implementation. + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h new file mode 100644 index 0000000000..a1e908aef3 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/*! \brief The base ir_builder frame for the relax dialect. */ +class RelaxFrameNode : public IRBuilderFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); +}; + +class RelaxFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); + + protected: + RelaxFrame() = default; +}; + +/*! \brief The base ir_builder frame for frames with a SeqExpr, + i.e. functions and if branches. + */ +class SeqExprFrameNode : public RelaxFrameNode { + public: + /*! \brief The binding blocks inside the frame. */ + Array<tvm::relax::BindingBlock> binding_blocks; + /*! \brief The frame output expr. `NullOpt` when undefined.
*/ + Optional<tvm::relax::Expr> output; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); + + public: + void ExitWithScope() override; +}; + +class SeqExprFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); +}; + +/*! \brief The ir_builder frame for the relax function. */ +class FunctionFrameNode : public SeqExprFrameNode { + public: + /*! + * \brief The function name. + * \note The name is not specified in the constructor, so it is optional; + * however, it must be specified via `R.func_name` before exiting this frame. + */ + Optional<String> name; + /*! \brief The function params. */ + Array<tvm::relax::Var> params; + /*! + * \brief The function return type. + * \note Usually the function return type can be deduced from the function body. + * But this field can be used to specify a more "accurate" return type: + * if `ret_type` is None, the type deduced from the body is used; + * if `ret_type` is not None, the deduced type is checked to be a base type of the given one. + */ + Optional<Type> ret_type; + /*! \brief The function attributes. */ + Map<String, ObjectRef> attrs; + /*! \brief The block builder used to create the Relax function. */ + tvm::relax::BlockBuilder block_builder; + + void VisitAttrs(tvm::AttrVisitor* v) { + SeqExprFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("params", &params); + v->Visit("ret_type", &ret_type); + v->Visit("attrs", &attrs); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + // `block_builder` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); + + public: + void ExitWithScope() final; +}; + +class FunctionFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); +}; + +/*! \brief The ir_builder frame for relax binding blocks. */ +class BlockFrameNode : public RelaxFrameNode { + public: + /*! \brief The flag that indicates whether the block is a dataflow block. */ + bool is_dataflow; + /*! \brief The variables emitted in this block. */ + Array<tvm::relax::Var> emitted_vars; + /*! + * \brief (Only used for a dataflow block.) A boolean indicating whether construction of the + * dataflow block has ended. If it is true, emitting any new binding into this block will cause + * an error. + */ + bool block_ended; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("is_dataflow", &is_dataflow); + v->Visit("emitted_vars", &emitted_vars); + v->Visit("block_ended", &block_ended); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); +}; + +/*! + * \brief A frame that represents an if statement. + * + * \sa IfFrame + */ +class IfFrameNode : public RelaxFrameNode { + public: + /*! \brief The condition of the if statement. */ + tvm::relax::Expr condition; + /*! \brief The expression in the true branch.
*/ + Optional<tvm::relax::Expr> then_expr; + /*! \brief The expression in the false branch. */ + Optional<tvm::relax::Expr> else_expr; + /*! \brief The binding var. */ + tvm::relax::Var var; + /*! \brief The binding var name. */ + String var_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_expr", &then_expr); + v->Visit("else_expr", &else_expr); + v->Visit("var", &var); + v->Visit("var_name", &var_name); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ +class IfFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); +}; + +/*! + * \brief A frame that represents a then branch. + * + * \sa ThenFrame + */ +class ThenFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ +class ThenFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); +}; + +/*! + * \brief A frame that represents an else branch. + * + * \sa ElseFrame + */ +class ElseFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ +class ElseFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); +}; + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h new file mode 100644 index 0000000000..2f31d220cd --- /dev/null +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +////////////////////////////// Tensor Type ////////////////////////////// + +/*! \brief A temporary Tensor type for `R.Tensor` in ir_builder. */ +class TensorTypeNode : public runtime::Object { + public: + /*! \brief The type, usually a DynTensorType. */ + tvm::relax::DynTensorType type; + /*! \brief The shape, which is optional. */ + Optional<tvm::relax::ShapeExpr> shape; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("type", &type); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.TensorType"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, runtime::Object); +}; + +class TensorType : public runtime::ObjectRef { + public: + TVM_DLL explicit TensorType(tvm::relax::DynTensorType type, Optional<tvm::relax::ShapeExpr> shape); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorType, ObjectRef, TensorTypeNode); +}; + +/*! + * \brief Create a TensorType for a DynTensor. + * \param shape The shape of the tensor. It's runtime dependent if `shape` is None. + * \param dtype The element data type of the tensor. It's runtime dependent if `dtype` is None. + * \param ndim The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1. + * \return The TensorType that is only used in ir_builder. + */ +TVM_DLL TensorType Tensor(Optional<Array<PrimExpr>> shape, DataType dtype, int ndim = -1); + +/////////////////////////////// Function //////////////////////////////// + +/*! + * \brief Start a function frame. + * \return The created ir_builder Function frame. + */ +TVM_DLL FunctionFrame Function(); + +/*! + * \brief Add a parameter to the last function frame. + * \param name The name of the parameter. + * \param type The type of the parameter. + * \param shape The shape of the parameter. + * \return The created function parameter var. + */ +TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type, + const tvm::relax::ShapeExpr& shape); + +/*! + * \brief Specify the name of the last function frame. + * \param name The function name. + */ +TVM_DLL void FuncName(const String& name); + +/*! + * \brief Specify the attrs of the last function frame. + * \param attrs The function attrs. + */ +TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs); + +/*! + * \brief Specify the return type of the last function frame. + * \param ret_type The return type. Note: it's a standard `tvm::Type` instead of TensorType. + */ +TVM_DLL void FuncRetType(tvm::Type ret_type); + +/*! + * \brief Specify the return value of the last function frame. + * \param value The return value. + */ +TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); + +///////////////////////////// BindingBlock ////////////////////////////// + +/*! + * \brief Start a binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame BindingBlock(); + +/*! + * \brief Start a dataflow binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame Dataflow(); + +/*!
+ * \brief Expose the dataflow block output variables as global ones. + * \param vars The output variables of a dataflow block. + */ +TVM_DLL void DataflowBlockOutput(const Array<tvm::relax::Var>& vars); + +////////////////////////////// Bindings //////////////////////////////// + +/*! + * \brief Emit a binding to the last binding block frame. + * \param value The right side value of the binding to be emitted. + * \param is_dataflow_var A boolean indicating if the emitted binding variable is a dataflow + * variable. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value, bool is_dataflow_var); + +/*! + * \brief Emit a match_shape binding to the last binding block frame. + * \param value The value of the MatchShape to be emitted. + * \param pattern The pattern of the MatchShape to be emitted. + * \param emit_var A boolean indicating if the MatchShape contains the emitted variable. + * \param is_dataflow_var A boolean indicating if the emitted variable is a dataflow variable when + * `emit_var` is true. When `emit_var` is false, the value of this flag will be ignored. + * \return The emitted var if `emit_var` is true. Otherwise, return `NullOpt`. + */ +TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value, // + const Array<PrimExpr>& pattern, // + bool emit_var, // + bool is_dataflow_var); + +///////////////////////////// Type Deduce ////////////////////////////// + +/*! + * \brief Annotate and check the type and shape of a relax var. + * \param var The input var to be annotated. + * \param anno_type The annotated type. + * \param anno_shape The annotated shape, which can be undefined. + * \note This function checks whether the type of the var is compatible with the annotated type, + * and annotates the var with the more detailed type. + */ +TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type, + const Optional<tvm::relax::Expr>& anno_shape); + +///////////////////////////// If Then Else ///////////////////////////// + +/*! + * \brief Create an if statement. + * \param condition The condition of the if statement. + * \return The result IfFrame. + */ +IfFrame If(tvm::relax::Expr condition); +/*! + * \brief Create a then branch. + * \return The result ThenFrame. + */ +ThenFrame Then(); +/*! + * \brief Create an else branch. + * \return The result ElseFrame. + */ +ElseFrame Else(); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 2c50f3c315..f5753afa56 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode { } static constexpr const char* _type_key = "PlaceholderOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); + TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; /*! diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e31919fbd2..1d778e3386 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,7 +56,7 @@ using tvm::transform::Sequential; */ TVM_DLL Pass CreatePrimFuncPass( const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, - int opt_level, String name, tvm::Array<String> required); + int opt_level, String name, tvm::Array<String> required, bool traceable = false); /*! * \brief Inject prefetch instructions into stmt.
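The relax ir_builder declarations above are frame-driven: open a FunctionFrame, add arguments, then emit bindings and a return value before the frame exits. Below is a minimal sketch of how the corresponding Python bindings might be driven, assuming they mirror these C++ entry points one-to-one; the module path, the lower-case helper names, and the two-argument `arg` form are assumptions for illustration, not part of this patch.

    # Hypothetical usage sketch of the ir_builder API declared above.
    from tvm.script.ir_builder import IRBuilder
    from tvm.script.ir_builder import relax as R

    with IRBuilder() as ib:
        with R.function():                                 # FunctionFrame
            R.func_name("identity")                        # must be set before the frame exits
            x = R.arg("x", R.tensor((32, 32), "float32"))  # add a param; returns the relax Var
            R.func_ret_value(x)                            # the frame output
    func = ib.get()                                        # materialize the built relax.Function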
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 59430eee83..9949a1a497 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -130,8 +131,8 @@ class BufferInfo : public ObjectRef { * for memory planning algorithms. */ struct BufferInfoAnalysisNode : public Object { - /*! \brief The BufferInfo object and its associated TIR statement */ - Map<BufferInfo, tir::Stmt> buffer_info_stmts; + /*! \brief The BufferInfo object and its associated TIR statement/Relax expr */ + Map<BufferInfo, ObjectRef> buffer_info_stmts; /*! \brief This represent maximum amount of memory being used at * any point of time in the inference. This value is largely the * best allocation an algorithm could achieve. Due to @@ -159,7 +160,8 @@ class BufferInfoAnalysis : public ObjectRef { public: - TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure); + TVM_DLL BufferInfoAnalysis(Map<BufferInfo, ObjectRef> buffer_info_stmts, + Integer memory_pressure); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode); }; @@ -240,7 +242,7 @@ class AllocatedPoolInfo : public ObjectRef { * * \param buffer_info_map IR-bound BufferInfo map */ -Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, tir::Stmt>& buffer_info_map); +Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, ObjectRef>& buffer_info_map); /*! * \brief Calculate workspace required to execute a IRModule with main expressed in TIR @@ -289,7 +291,7 @@ Integer CalculateExtentsSize(const AllocateConstNode* op); * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects */ Map<tir::Stmt, PoolAllocation> AssignStmtPoolAllocations( - const Map<BufferInfo, tir::Stmt>& buffer_info_to_stmt, + const Map<BufferInfo, ObjectRef>& buffer_info_to_stmt, const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation); /*! diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index ed78cd689e..1f1e4edcd9 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -111,7 +111,7 @@ def _traverse_expr(node): else: node_entry["inputs"].append([in_node_idx, 0, 0]) infer_out = _infer_type(node) - out_type = infer_out._checked_type_ + out_type = infer_out.checked_type_ if isinstance(out_type, TensorType): node_entry["types"].append(out_type) elif isinstance(out_type, TupleType): diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index b69382fe12..cb5ae75844 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -23,7 +23,9 @@ from typing import Union import tvm +from tvm import relax from tvm import rpc as _rpc +from tvm.contrib import utils import tvm.contrib.hexagon as hexagon from tvm.relay.backend.executor_factory import ( ExecutorFactoryModule, @@ -262,13 +264,13 @@ def get_graph_debug_executor( graph_json, graph_debug_mod, self.device, dump_root=str(dump_root) ) - def get_executor_from_factory(self, module: ExecutorFactoryModule): + def get_executor_from_factory(self, module: Union[ExecutorFactoryModule, relax.vm.Executable]): """Create a local GraphModule which consumes a remote libmod. Parameters ---------- - module : ExecutorFactoryModule + module : Union[ExecutorFactoryModule, relax.vm.Executable] The module to upload to the remote session and load.
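The widened `get_executor_from_factory` signature above lets a Hexagon session load Relax VM executables alongside graph and AOT artifacts; the dispatch branch follows in the next hunk. A minimal sketch of the intended flow, assuming `mod` is a Relax IRModule, `hexagon_session` is an already-started Session, and this target setup fits the device (all three are illustrative assumptions):

    # Hypothetical usage sketch: load a Relax VM executable over a Hexagon session.
    import tvm
    from tvm import relax

    target = tvm.target.Target("hexagon", host="llvm")  # assumed target configuration
    ex = relax.vm.build(mod, target)                    # a relax.vm.Executable
    remote_mod = hexagon_session.get_executor_from_factory(ex)  # uploads and loads on the remote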
@@ -277,6 +279,8 @@ def get_executor_from_factory(self, module: ExecutorFactoryModule): return self._aot_executor_from_factory(module) if isinstance(module, GraphExecutorFactoryModule): return self._graph_executor_from_factory(module) + if isinstance(module, relax.vm.Executable): + return self._relax_vm_executable_executor(module) raise TypeError(f"Unsupported executor type: {type(module)}") @@ -328,6 +332,38 @@ def _graph_executor_from_factory( """ return self.get_graph_executor(module.get_graph_json(), module.get_lib()) + def _relax_vm_executable_executor( + self, + vm_exec: relax.vm.Executable, + ): + """Create a local TVM module which consumes a remote vm executable. + + Parameters + ---------- + + vm_exec : relax.vm.Executable + The Relax VM Executable to upload to the remote and load. This will typically be the + output of `relax.vm.build`. + + Returns + ------- + TVMModule : + TVM module object + """ + assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + + vm_exec.mod.export_library( + path_exec, + fcompile=hexagon.create_aot_shared, + hexagon_arch="v68", + ) + + path = self.upload(path_exec, "exec.so") + return self._rpc.get_function("tvm.hexagon.load_module")(str(path)) + def _aot_executor_from_factory( self, module: Union[str, pathlib.Path, AOTExecutorFactoryModule], diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 272598f7c3..b214a1b9dd 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -85,7 +85,7 @@ def infer_type(node): def call_node_infer_type(node): """infer the output types of call node""" infer_out = infer_type(node) - out_type = infer_out._checked_type_ + out_type = infer_out.checked_type_ if isinstance(out_type, TensorType): types = [out_type] elif isinstance(out_type, TupleType): diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index c6b30d38ed..f7ea24ee10 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -248,7 +248,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): The left operand. rhs : Object - The left operand. + The right operand. map_free_vars : bool Whether or not shall we map free vars that does diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 43cba1a835..1ea02a7f82 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -47,9 +47,20 @@ def checked_type(self): """ ret = self._checked_type_ if ret is None: - raise ValueError("The type checker has not populated" " the checked_type for this node") + raise ValueError("The type checker has not populated the checked_type for this node") return ret + @property + def shape(self): + """Get the shape of the expression. + + Returns + ------- + shape : tvm.ir.RelayExpr + The expression that represents the shape. + """ + return _ffi_api.RelayExprShape(self) + @tvm._ffi.register_object("GlobalVar") class GlobalVar(RelayExpr): diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index c3f1bf5f56..a5a572b4b6 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. """Function defintiions.""" +from __future__ import annotations +from typing import Union, Dict from enum import IntEnum import tvm.runtime +from tvm.runtime.object import Object from .expr import RelayExpr +from .attrs import DictAttrs from .
import _ffi_api @@ -38,7 +42,7 @@ def attrs(self): """Return the attrs member of the function.""" return _ffi_api.BaseFunc_Attrs(self) - def with_attr(self, attr_key_or_dict, attr_value=None): + def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc: """Create a new copy of the function and update the attribute. Parameters @@ -51,7 +55,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None): Returns ------- - func : Function + func : BaseFunc A new copy of the function """ # make sure we first copy so that we can safely do copy on write @@ -66,3 +70,35 @@ def with_attr(self, attr_key_or_dict, attr_value=None): return _ffi_api.BaseFuncWithAttr( res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) + + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc: + """Copy the function and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + func : BaseFunc + A new copy of the function + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.BaseFuncWithAttrs(self, attr_map) + + def without_attr(self, attr_key: str) -> BaseFunc: + """Create a new copy of the function without the provided attribute key. + + Parameters + ---------- + attr_key : str + The attribute key to delete from the attribute pairs. + + Returns + ------- + func : BaseFunc + A new copy of the function + """ + return _ffi_api.BaseFuncWithoutAttr(self, attr_key) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 06537e2cdc..36656a5b4a 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,13 +15,17 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" -from typing import Optional - +from __future__ import annotations +from typing import Optional, Union, Dict +import ast from tvm._ffi.base import string_types import tvm._ffi +from tvm.runtime.object import Object from .base import Node from . import expr as _expr +from .attrs import DictAttrs +from ..ir.function import BaseFunc from . import type as _ty from . import _ffi_api @@ -38,7 +42,7 @@ class IRModule(Node): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -61,7 +65,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module.
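The `with_attrs`/`without_attr` helpers added above follow the same copy-on-write pattern as `with_attr`; `IRModule` below gains the same pair. A minimal round-trip sketch (the attribute key and value are illustrative):

    # Hypothetical round trip through the new attribute helpers.
    import tvm

    mod = tvm.IRModule({})
    mod = mod.with_attrs({"key": tvm.runtime.convert(1)})  # copy with an attribute map added
    assert int(mod.get_attrs()["key"]) == 1
    mod = mod.without_attr("key")                          # copy with the key removed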
@@ -77,7 +91,7 @@ def __setitem__(self, var, val): return self._add(var, val, True) def _add(self, var, val, update=True): - if isinstance(val, _expr.RelayExpr): + if isinstance(val, (_expr.RelayExpr, BaseFunc)): if isinstance(var, string_types): if _ffi_api.Module_ContainGlobalVar(self, var): var = _ffi_api.Module_GetGlobalVar(self, var) @@ -281,6 +295,7 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: def show(self, style: Optional[str] = None) -> None: """ A sugar for print highlighted TVM script. + Parameters ---------- style : str, optional @@ -307,7 +322,18 @@ def get_attr(self, attr_key): return _ffi_api.Module_GetAttr(self, attr_key) - def with_attr(self, attr_key, attr_value): + def get_attrs(self): + """Get the meta_data attributes. + + Returns + ------- + meta_data : DictAttrs + The meta_data attributes + """ + + return _ffi_api.Module_GetAttrs(self) + + def with_attr(self, attr_key, attr_value) -> IRModule: """Copy the IRModule and add an attribute to it. Parameters @@ -325,3 +351,33 @@ def with_attr(self, attr_key, attr_value): return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + + def without_attr(self, attr_key: str) -> IRModule: + """Copy the IRModule and remove an attribute key and its associated value. + Parameters + ---------- + attr_key : str + The attribute key. + Returns + ------- + mod : IRModule + A new copy of the IRModule without the attribute. + """ + + return _ffi_api.Module_WithoutAttr(self, attr_key) + + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> IRModule: + """Copy the IRModule and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + mod : IRModule + A new copy of the IRModule with the attributes added. + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.Module_WithAttrs(self, attr_map) diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 17995bfa78..edc38aae92 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -22,7 +22,6 @@ import tvm._ffi import tvm.runtime - from . import _ffi_transform_api @@ -45,8 +44,10 @@ class PassInfo(tvm.runtime.Object): The list of passes that are required by a certain pass. """ - def __init__(self, opt_level, name, required=None): - self.__init_handle_by_constructor__(_ffi_transform_api.PassInfo, opt_level, name, required) + def __init__(self, opt_level, name, required=None, traceable=False): + self.__init_handle_by_constructor__( + _ffi_transform_api.PassInfo, opt_level, name, required, traceable + ) @tvm._ffi.register_object("transform.PassContext") @@ -70,6 +71,20 @@ class PassContext(tvm.runtime.Object): config : Optional[Dict[str, Object]] Additional configurations for specific passes. + + trace: Optional[relax.tuning_api.Trace] + Initial trace for trace mode. + + trace_stack: Optional[List[relax.tuning_api.Trace]] + Initial trace stack for trace mode. + + make_traceable: Optional[List[str]] + List of passes to make traceable. + + num_evals: int + Initial number of evaluations conducted in the pipeline.
+ + tuning_api_database: Optional[relax.tuning_api.JSONDatabase] + The database used by the tuning API. """ def __init__( @@ -79,6 +94,11 @@ def __init__( self, opt_level=2, required_pass=None, disabled_pass=None, instruments=None, config=None, + trace=None, + trace_stack=None, + make_traceable=None, + num_evals=0, + tuning_api_database=None, ): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): @@ -92,9 +112,25 @@ def __init__( if not isinstance(instruments, (list, tuple)): raise TypeError("instruments is expected to be the type of " + "list/tuple/set.") + # Convert to Map + # TODO(sunggg): Replace this with a Set equivalent if one exists + make_traceable = {name: True for name in make_traceable} if make_traceable else None + + if not trace_stack: + trace_stack = [trace] if trace else [] + config = config if config else None self.__init_handle_by_constructor__( - _ffi_transform_api.PassContext, opt_level, required, disabled, instruments, config + _ffi_transform_api.PassContext, + opt_level, + required, + disabled, + instruments, + config, + trace_stack, + make_traceable, + num_evals, + tuning_api_database, ) def __enter__(self): @@ -131,6 +167,47 @@ def list_configs(): """ return _ffi_transform_api.ListConfigs() + def push_trace(self, trace): + """Push a trace onto the stack.""" + return _ffi_transform_api.PushTrace(self, trace) + + def pop_trace(self, return_current=True): + """Pop the topmost trace from the stack. + Returns + ------- + Trace : Optional[relax.tuning_api.Trace] + """ + if return_current: + cur_trace = self.get_current_trace() + _ffi_transform_api.PopTrace(self) + return cur_trace + + return _ffi_transform_api.PopTrace(self) + + def get_trace_stack(self): + """Get the current trace stack.""" + return _ffi_transform_api.GetTraceStack(self) + + def get_trace_stack_size(self): + """Get the size of the current trace stack.""" + return _ffi_transform_api.GetTraceStackSize(self) + + def get_current_trace(self): + """Get the trace on the top of the stack.""" + return _ffi_transform_api.GetCurrentTrace(self) + + def set_num_evals(self, num: int): + """Set the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.SetNumEvals(self, num) + + def inc_num_evals(self, num: int): + """Increment the number of evaluations conducted in the pipeline.""" + return _ffi_transform_api.IncNumEvals(self, num) + + def get_tuning_api_database(self): + """Get the tuning API database.""" + return _ffi_transform_api.GetTuningAPIDatabase(self) + @tvm._ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): @@ -199,7 +276,7 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on.
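+ traceable : bool + Whether the sequential pass is traceable.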
""" - def __init__(self, passes=None, opt_level=0, name="sequential", required=None): + def __init__(self, passes=None, opt_level=0, name="sequential", required=None, traceable=False): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): raise TypeError("passes must be a list of Pass objects.") @@ -209,7 +286,7 @@ def __init__(self, passes=None, opt_level=0, name="sequential", required=None): raise TypeError("Required is expected to be the type of list/tuple.") self.__init_handle_by_constructor__( - _ffi_transform_api.Sequential, passes, opt_level, name, required + _ffi_transform_api.Sequential, passes, opt_level, name, required, traceable ) @@ -245,7 +322,7 @@ def __getattr__(self, name): return PyModulePass -def module_pass(pass_func=None, opt_level=None, name=None, required=None): +def module_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False): """Decorate a module pass. This function returns a callback when pass_func is provided. @@ -270,6 +347,9 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): required : Optional[List[str]] The list of passes that the module pass is dependent on. + traceable: Boolean + Boolean variable whether the module pass is traceable + Returns ------- create_module_pass : Union[Callable, ModulePass] @@ -337,7 +417,7 @@ def transform(mod, ctx): def create_module_pass(pass_arg): """Internal function that creates a module pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_module_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 4fe28f1d72..3748b7dcec 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -19,6 +19,7 @@ import tvm import tvm._ffi +from . import Span from .base import Node from . import _ffi_api @@ -166,8 +167,8 @@ class TupleType(Type): The fields in the tuple """ - def __init__(self, fields): - self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) + def __init__(self, fields, span: Span = None): + self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span) @tvm._ffi.register_object("TypeConstraint") diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 04acdc9d4a..4d38cb18d0 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -24,6 +24,7 @@ measure_callback, mutator, postproc, + relax_integration, relay_integration, runner, schedule_rule, diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py new file mode 100644 index 0000000000..a82d899685 --- /dev/null +++ b/python/tvm/meta_schedule/relax_integration.py @@ -0,0 +1,352 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta schedule integration with high-level IR""" +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal + +# isort: on + +from tvm._ffi import get_global_func, register_func +from tvm.ir import IRModule +from tvm.ir.transform import PassContext +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.tir.expr import IntImm + +from .builder import Builder +from .cost_model import CostModel +from .database import Database +from .extracted_task import ExtractedTask +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext +from .utils import fork_seed + +if TYPE_CHECKING: + from tvm import relax + +_extract_task_func = get_global_func( # pylint: disable=invalid-name + "relax.backend.MetaScheduleExtractTask", + allow_missing=False, +) + + +def extract_tasks( + mod: Union[IRModule, "relax.Function"], + target: Target, + params: Optional[Dict[str, NDArray]] = None, +) -> List[ExtractedTask]: + """Extract tuning tasks from a relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, NDArray]] + The associated parameters of the program + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this module + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.expr import Function as RelaxFunc + from tvm.relax.transform import BindParams + + # pylint: enable=import-outside-toplevel + if isinstance(mod, RelaxFunc): + mod = IRModule({"main": mod}) + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + return list(_extract_task_func(mod, target)) + + +def extracted_tasks_to_tune_contexts( + extracted_tasks: List[ExtractedTask], + work_dir: str, + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Tuple[List[TuneContext], List[float]]: + """Convert ExtractedTask to TuneContext. + + Parameters + ---------- + extracted_tasks : List[ExtractedTask] + The tasks to be converted + work_dir : str + The working directory to store logs and databases + space : SpaceGenerator.SpaceGeneratorType + The space generator to use. + strategy : SearchStrategy.SearchStrategyType + The search strategy to use. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use in multi-threaded search algorithm. + seed : Optional[int] + The random seed to use.
+ + Returns + ------- + tasks : List[TuneContext] + The converted tasks + task_weights : List[float] + The weights of the tasks + """ + tasks: List[TuneContext] = [] + task_weights: List[float] = [] + for task, logger, rand_state in zip( + extracted_tasks, + get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), + fork_seed(seed, n=len(extracted_tasks)), + ): + tasks.append( + TuneContext( + mod=task.dispatched[0], + target=task.target, + space_generator=space, + search_strategy=strategy, + task_name=task.task_name, + logger=logger, + rand_state=rand_state, + num_threads=num_threads, + ).clone() + ) + task_weights.append(task.weight) + return tasks, task_weights + + +def tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Tune a Relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + database : Database + The database that contains the tuning records + """ + tasks, task_weights = extracted_tasks_to_tune_contexts( + extracted_tasks=extract_tasks(mod, target, params), + work_dir=work_dir, + space=space, + strategy=strategy, + seed=seed, + ) + return tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + ) + + +@register_func("tvm.meta_schedule.tune_relax") +def _tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + 
cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> IRModule: + """Interface with the tuning API to tune a Relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + ret_mod : IRModule + The original IRModule; this entry point only records tuning decisions. + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + + tune_relax( + mod, + params, + target, + work_dir, + max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decisions + return mod + + +def compile_relax( + database: Database, + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], +) -> "relax.vm.Executable": + """Compile a relax program with a MetaSchedule database. + + Parameters + ---------- + database : Database + The database to use + mod : IRModule + The Relax program to be compiled + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + + Returns + ------- + lib : relax.vm.Executable + The built runtime module or vm Executable for the given relax workload.
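+ It can then be loaded and executed with `relax.VirtualMachine`.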
+ """ + # pylint: disable=import-outside-toplevel + from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase + from tvm.relax.vm import build as relax_build + + # pylint: enable=import-outside-toplevel + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + + with target, database, PassContext(opt_level=3): + relax_mod = MetaScheduleApplyDatabase()(mod) + relax_ex = relax_build(relax_mod, target=target) + return relax_ex diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index 975987ebcb..db88228a43 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -22,7 +22,9 @@ # isort: on from tvm import ir, tir +from tvm._ffi import register_func from tvm.target import Target +from tvm.tir.expr import IntImm from .builder import Builder from .cost_model import CostModel @@ -128,6 +130,93 @@ def tune_tir( ) +@register_func("tvm.meta_schedule.tune_tir") +def _tune_tir( + mod: Union[ir.IRModule, tir.PrimFunc], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "round-robin", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + task_name: str = "main", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Database: + """Interface with tuning api to tune a TIR program. + + Parameters + ---------- + mod : Union[ir.IRModule, tir.PrimFunc] + The TIR function to tune. + target : Union[str, Target] + The target to tune for. + work_dir : str + The working directory. + max_trials_global : int + The maximum number of trials to run globally. + num_trials_per_iter : int + The number of trials to run per iteration + builder : Builder.BuilderType + The builder. + runner : Runner.RunnerType + The runner. + database : Database.DatabaseType + The database. + cost_model : CostModel.CostModelType + The cost model. + measure_callbacks : MeasureCallback.CallbackListType + The measure callbacks. + task_scheduler : TaskScheduler.TaskSchedulerType + The task scheduler. + space : SpaceGenerator.SpaceGeneratorType + The space generator. + strategy : SearchStrategy.SearchStrategyType + The search strategy. + task_name : str + The name of the task. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use. + seed : Optional[int] + The seed for the random number generator. 
+ + Returns + ------- + ret_mod : IRModule + The original IRModule; this entry point only records tuning decisions. + """ + if isinstance(max_trials_global, IntImm): + max_trials_global = int(max_trials_global) + tune_tir( + mod, + target, + work_dir, + max_trials_global, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + strategy=strategy, + task_name=task_name, + num_threads=num_threads, + seed=seed, + ) + # Return original IRModule + # This pass only makes optimization decisions + return mod + + def compile_tir( database: Database, mod: Union[ir.IRModule, tir.PrimFunc], diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 38a46ebe75..8c4f4ce864 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -24,7 +24,7 @@ # isort: on from tvm import IRModule -from tvm._ffi import register_object +from tvm._ffi import register_object, register_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule @@ -41,6 +41,7 @@ from .space_generator import SpaceGenerator +@register_func("tvm.meta_schedule.normalize_mod") def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index eb3c643760..483f8fa8cd 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -75,14 +75,27 @@ def _extract(inst: type, name: str): def method(*args, **kwargs): return getattr(inst, name)(*args, **kwargs) - if getattr(base, name) is getattr(cls, name) and name != "__str__": - # for task scheduler return None means calling default function - # otherwise it will trigger a TVMError of method not implemented - # on the c++ side when you call the method, __str__ not required - return None - return method + for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): + # extract functions that differ from the base class + if not hasattr(base_cls, name): + continue + if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": + continue + return method + + # For the task scheduler, returning None means calling the default function; + # otherwise calling the method triggers a TVMError of "method not implemented" + # on the C++ side. __str__ is not required. + return None assert isinstance(cls.__base__, type) + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": + raise TypeError( + ( + f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " + f"Please inherit from `{cls.__name__}._cls`." + ) + ) assert hasattr( cls, "_tvm_metadata" ), "Please use the user-facing method overriding class, i.e., PyRunner."
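The `TVMDerivedObject` guard added above rejects subclassing an already-decorated class and points users at its `_cls` attribute, which the next hunk introduces. A minimal sketch of the pattern it enforces, assuming the usual `derived_object` decorator and `PyRunner` base from meta_schedule (the class names and trivial bodies are illustrative):

    # Hypothetical sketch of the inheritance rule enforced above.
    from tvm.meta_schedule.runner import PyRunner
    from tvm.meta_schedule.utils import derived_object

    @derived_object
    class MyRunner(PyRunner):
        def run(self, runner_inputs):
            return []

    # @derived_object
    # class Sub(MyRunner): ...     # would raise TypeError: inherit from MyRunner._cls

    @derived_object
    class Sub(MyRunner._cls):       # the undecorated class is kept on `_cls`
        def run(self, runner_inputs):
            return []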
@@ -95,6 +108,9 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + _cls = cls + _type = "TVMDerivedObject" + def __init__(self, *args, **kwargs): """Constructor.""" self.handle = None @@ -111,12 +127,22 @@ def __init__(self, *args, **kwargs): # using weakref to avoid cyclic dependency self._inst._outer = weakref.ref(self) - def __getattr__(self, name: str): - """Bridge the attribute function.""" - try: - return self._inst.__getattribute__(name) - except AttributeError: - return super(TVMDerivedObject, self).__getattr__(name) + def __getattr__(self, name): + # Fall back to the wrapped instance's attribute. + import inspect # pylint: disable=import-outside-toplevel + + result = self._inst.__getattribute__(name) + if inspect.ismethod(result): + + def method(*args, **kwargs): + return result(*args, **kwargs) + + # set __own__ to avoid implicit destruction + setattr(method, "__own__", self) + return method + + return result def __setattr__(self, name, value): if name not in ["_inst", "key", "handle"]: diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py new file mode 100644 index 0000000000..ae61cec2c3 --- /dev/null +++ b/python/tvm/relax/__init__.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, wrong-import-position +"""The Relax IR namespace containing the IR, type, operator, and builder.""" +from . import exec_builder +from . import expr +from . import ty +from . import vm +from . import block_builder +from . import op +from . import analysis +from . import transform +from .
import expr_functor + +# Expr +Expr = expr.Expr +Span = expr.Span +SourceName = expr.SourceName +Id = expr.Id +GlobalVar = expr.GlobalVar +Var = expr.Var +DataflowVar = expr.DataflowVar +Binding = expr.Binding +MatchShape = expr.MatchShape +VarBinding = expr.VarBinding +BindingBlock = expr.BindingBlock +DataflowBlock = expr.DataflowBlock +SeqExpr = expr.SeqExpr +ShapeExpr = expr.ShapeExpr +RuntimeDepShape = expr.RuntimeDepShape +Tuple = expr.Tuple +TupleGetItem = expr.TupleGetItem +Function = expr.Function +ExternFunc = expr.ExternFunc +Call = expr.Call +If = expr.If + +# helper functions +const = expr.const +Constant = expr.Constant +extern = expr.extern +te_tensor = expr.te_tensor + +# Type +Type = ty.Type +ShapeType = ty.ShapeType +ObjectType = ty.ObjectType +DynTensorType = ty.DynTensorType +DimType = ty.DimType +TupleType = ty.TupleType +FuncType = ty.FuncType + +# VM +ExecBuilder = exec_builder.ExecBuilder +VirtualMachine = vm.VirtualMachine + +# Operator +from .op.base import call_tir, make_closure, invoke_closure +from .op.op_attrs import VMAllocStorageAttrs, VMAllocTensorAttrs + +# IRBuilder +BlockBuilder = block_builder.BlockBuilder + +# ExprFunctor +ExprFunctor = expr_functor.ExprFunctor +PyExprVisitor = expr_functor.PyExprVisitor +PyExprMutator = expr_functor.PyExprMutator diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py new file mode 100644 index 0000000000..a127e1c813 --- /dev/null +++ b/python/tvm/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for Relax.""" +import tvm._ffi + +tvm._ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py new file mode 100644 index 0000000000..cc0089ff31 --- /dev/null +++ b/python/tvm/relax/analysis/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax IR analysis. 
""" + +from .analysis import * diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py new file mode 100644 index 0000000000..75f119438c --- /dev/null +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.analysis""" +import tvm._ffi + +tvm._ffi._init_api("relax.analysis", __name__) diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py new file mode 100644 index 0000000000..e5df317a7f --- /dev/null +++ b/python/tvm/relax/analysis/analysis.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the set of passes for Relax, which exposes an interface for +configuring the passes and scripting them in Python. +""" + +from typing import Dict, List + +import tvm +from tvm import tir +from tvm.relax.expr import DataflowBlock, GlobalVar, Var, Expr, Function, Binding +from tvm.ir.module import IRModule +from . import _ffi_api + + +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fvisit : function + The visitor function to be applied. + """ + return _ffi_api.post_order_visit(expr, fvisit) + + +def well_formed(mod: tvm.IRModule) -> bool: + """Check if the IRModule is well formed. + + Parameters + ---------- + mod : tvm.IRModule + The input IRModule. + + Returns + ------- + ret: bool + True if the IRModule is well formed, False if not. + """ + return _ffi_api.well_formed(mod) + + +def get_var2val(func: Function) -> Dict[Var, Expr]: + """ + Get a mapping from Var to Expr for each variable in the function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Dict[Var, Expr] + A mapping from Var to Expr. 
+ """ + return _ffi_api.get_var2val(func) + + +def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: + """ + Analyze the variable use-def chain in a dataflow block. + + Parameters + ---------- + dfb : DataflowBlock + The dataflow block to analyze + + Returns + ------- + Dict[Var, List[Var]] + A mapping from variable definition to its uses. + """ + return _ffi_api.udchain(dfb) + + +def name_to_binding(func: Function) -> Dict[str, List[Binding]]: + """Return a map from variable name to its bindings.""" + return _ffi_api.name_to_binding(func) + + +def remove_all_unused(func: Function) -> Function: + """Remove all unused variables from the function. + + Parameters + ---------- + func : Function + The input function to be analyzed. + + Returns + ------- + Function + The function with unused variables removed. + """ + return _ffi_api.remove_all_unused(func) + + +def shape_vars(expr: Expr) -> List[tir.Var]: + """ + Returns all shape variables (TIR variables) in the given expression. + + Note that the expression is intended to be a shape expression, i.e., + one used as the `shape_` for another expression. + + Parameters + ---------- + expr : Expr + The expression. Meant to be a shape expression. + + Returns + ------- + ret: List[tir.Var] + A list of all shape variables (TIR variables) in the expression. + """ + return _ffi_api.shape_vars(expr) + + +def derive_func_ret_shape(args: List[Var], body: Expr) -> Expr: + """ + Given the argument vars and body, derives a return shape for + a function with those args and that body. + If the body's shape contains free shape vars (those not used in the args), the + return shape is relaxed to RuntimeDepShape; otherwise, the body's shape is used. + + Parameters + ---------- + args: List[Var] + The argument variables, ideally with the shape_ field filled in + + body: Expr + The functino body, ideally with the shape_ field filled in + + Returns + ------- + ret: Expr + An expression that can serve as the return shape for the function + """ + return _ffi_api.derive_func_ret_shape(args, body) + + +def bound_vars(expr: Expr) -> List[Var]: + """ + Return all bound variables from expression expr. + + Bound variables are all variables that are declared in the expr. + They only have meaning inside that expr, and can only be used in it. + + Parameters + ---------- + expr: Expr + The expression. + + Returns + ------- + ret: List[Var] + List of bound vars in expr, in post-DFS order + """ + return _ffi_api.bound_vars(expr) + + +def free_vars(expr: Expr) -> List[Var]: + """ + Return all free variables from expression expr. + + Free variables are variables that are not bound by a + VarBinding or a function parameter in the expression. + + Parameters + ---------- + expr: Expr + The expression. + + Returns + ------- + ret: List[Var] + List of free vars in expr, in post-DFS order + """ + return _ffi_api.free_vars(expr) + + +def all_vars(expr: Expr) -> List[Var]: + """ + Return all (local) variables from expression expr. + + Parameters + ---------- + expr: Expr + The expression. + + Returns + ------- + ret: List[Var] + List of vars in expr, in post-DFS order + """ + return _ffi_api.all_vars(expr) + + +def all_global_vars(expr: Expr) -> List[GlobalVar]: + """ + Return all global variables from expression expr. + + Parameters + ---------- + expr: Expr + The expression. 
+ + Returns + ------- + ret: List[GlobalVar] + List of global vars in expr, in post-DFS order + """ + return _ffi_api.all_global_vars(expr) + + +def called_global_vars(expr: Expr) -> List[GlobalVar]: + """ + Return all global vars called (potentially recursively) from expr. + + Parameters + ---------- + expr: Expr + The expression + + Returns + ------- + ret: List[GlobalVar] + List of global vars that are used recursively in expr, + in post-DFS order + """ + return _ffi_api.called_global_vars(expr) + + +def extract_buffer_info(main_func: Function, mod: IRModule): + """This analysis pass consumes an IRModule with a Relax main function + that defines a ordering in the callees to operators and produces BufferInfo + objects that contains information about tensor allocations and liveness + conflicts between allocations. + + Parameters + ---------- + main_func: tvm.relax.Function + The main function containing calls to operator PrimFuncs. + mod : tvm.ir.IRModule + The full IRModule containing all Relax and Prim functions. + + Returns + ------- + Map + extracted buffer info objects + """ + return _ffi_api.extract_buffer_info(main_func, mod) diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py new file mode 100644 index 0000000000..3f0c8d5cb3 --- /dev/null +++ b/python/tvm/relax/binding_rewrite.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name +"""Developer API of add/remove/replace bindings in Relax.""" + +from typing import Optional + +import tvm +import tvm._ffi +from tvm.runtime import Object +from . import Binding, DataflowBlock, Expr, Function, Var +from . import _ffi_api + + +@tvm._ffi.register_object("relax.DataflowBlockRewrite") +class DataflowBlockRewrite(Object): + """ + A binding/statement-level dataflow block rewriter. + + Notes + ----- + Due to the immutable and copy-on-write nature of TVM AST nodes, the rewriting is not done in + place. Instead, a new DataflowBlock is created and returned with mutated_dfb. Similarly, its new + root Function is created and returned by mutated_root_fn. To apply this change for an IRModule, + use mutate_irmodule which rewrites the old function that registered in the constructor. + """ + + def __init__(self, dfb: DataflowBlock, root_fn: Function): + """ + Construct a rewriter with the DataflowBlock to rewrite and its root function. + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to rewrite. + root_fn : Function + The root function of the DataflowBlock. 
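A short sketch of the copy-on-write workflow this class implies; the names `mod`, `fn`, and `dfb` are assumptions standing in for a real IRModule, a relax.Function inside it, and a DataflowBlock of that function:

.. code-block:: python

    from tvm.relax.binding_rewrite import DataflowBlockRewrite

    def prune_dataflow_block(mod, fn, dfb):
        rewriter = DataflowBlockRewrite(dfb, fn)
        rewriter.remove_all_unused()         # the originals stay untouched
        # return a fresh IRModule with the registered function rewritten
        return rewriter.mutate_irmodule(mod)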
+ """ + self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else None + self.__init_handle_by_constructor__(_ffi_api.DataflowBlockRewrite, dfb, root_fn) + + def replace_all_uses(self, old_var: Var, new_var: Var) -> None: + """ + Replace all uses of old_var with new_var. + + Parameters + ---------- + old_var : Var + The old variable to replace. + new_var : Var + The new variable to replace with. + """ + _ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) + + def add_binding(self, binding: Binding) -> None: + return _ffi_api.dfb_rewrite_add_binding(self, binding) + + def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) -> None: + """ + Add a new statement to the DataflowBlock with an automatically generated variable name. + + Parameters + ---------- + expr : Expr + The expression to add. + name : Optional[str], optional + Variable name, by default None + is_dfvar : bool, optional + The variable type, by default False + + Notes + ----- + If the variable name is not given, it will be automatically generated in a form of + "tmp${COUNTER}". The variable type will be DataflowVar if is_dfvar is True, otherwise + it will be Var. Being Var means the variables are output variables of the DataflowBlock. + While being DataflowVar means the variables are internal variables of the DataflowBlock. + """ + _ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) + + def remove_unused(self, var: Var, allow_undef=False) -> None: + """ + Remove a statement by its variable definition if and only if it is unused. + + Parameters + ---------- + var : Var + The unused variable definition. + allow_undef : bool, optional + Whether to allow var being undefined variable, by default False + + Raises + ------ + TVMError if the variable is used or undefined (allow_undef=False). + """ + _ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) + + def remove_all_unused(self) -> None: + """ + Remove all unused variables. + + Notes + ----- + This could remove unused variables in other DataflowBlocks as well. + """ + _ffi_api.dfb_rewrite_remove_all_unused(self) + + def mutated_dfb(self) -> DataflowBlock: + """ + Returns the mutated DataflowBlock. + """ + return self.dfb + + def mutated_root_fn(self) -> Function: + """ + Returns the mutated root function. + """ + ret = self.root_fn + if self.func_name: + ret.__name__ = self.func_name + return ret + + def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule: + """ + Return an updated IRModule by replacing the old function with the mutated root function. + + Parameters + ---------- + irmodule : tvm.IRModule + The base IRModule to update. + + Returns + ------- + tvm.IRModule + The updated IRModule. + """ + ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) + if hasattr(irmodule, "__name__"): + ret.__name__ = irmodule.__name__ + return ret diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py new file mode 100644 index 0000000000..154b7897e5 --- /dev/null +++ b/python/tvm/relax/block_builder.py @@ -0,0 +1,757 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name
+"""Developer API of constructing Relax AST."""
+import typing
+
+from typing import Dict, List, Optional, Union, Any, Callable
+from tvm.ir.module import IRModule
+from tvm.runtime import Object
+from tvm import relax as rx, tir
+import tvm
+from .expr import (
+    Expr,
+    te_tensor,
+    Var,
+    ShapeExpr,
+    GlobalVar,
+    PrimExpr,
+    BindingBlock,
+    Tuple,
+    BaseFunc,
+    VarBinding,
+    MatchShape,
+)
+from .op.base import call_tir
+from . import _ffi_api
+
+
+class FunctionScope(object):
+    """Auxiliary scope for function"""
+
+    def __init__(self, block_builder, name, params, attrs):
+        self._bb = block_builder
+        self._name = name
+        self._params = params
+        self._attrs = attrs
+
+    def __enter__(self):
+        self._bb._enter_function_scope(self._name, self._params, self._attrs)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # __exit__ should properly handle the case where the with block exits
+        # with an exception; when handling the error case in __exit__, always
+        # check whether an exception has already been thrown in the with block
+        self._bb._exit_function_scope(exc_type, exc_val, exc_tb)
+
+
+class DataflowScope(object):
+    """Auxiliary scope for Dataflow block"""
+
+    def __init__(self, block_builder):
+        self._bb = block_builder
+
+    def __enter__(self):
+        block = self._bb._end_block()
+        if len(block.bindings) > 0:
+            self._bb._blocks.append(block)
+        self._bb._begin_dataflow_block()
+
+    def __exit__(self, ptype, value, trace):
+        block = self._bb._end_block()
+        if len(block.bindings) > 0:
+            self._bb._blocks.append(block)
+        self._bb._begin_binding_block()
+
+
+@tvm._ffi.register_object("relax.BlockBuilder")
+class BlockBuilder(Object):
+    """A builder to build Relax IR for testing and dev.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        m = tir.Var("m", "int32")
+        n = tir.Var("n", "int32")
+        type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+        type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+        x = rx.Var("x", [m, n], type_anno0)
+        y = rx.Var("y", [n], type_anno1)
+        bb = rx.BlockBuilder()
+        with bb.function("func", [x, y]):
+            with bb.dataflow():
+                lv0 = bb.emit(rx.add(x, y))
+                lv1 = bb.emit(rx.multiply(lv0, y))
+                gv0 = bb.emit_output(lv1)
+            bb.emit_func_output(gv0)
+        mod = bb.get()
+
+    BlockBuilder can also be used to construct neural networks with the nn.Module API
+
+    .. code-block:: python
+
+        from tvm.relax.testing import nn
+
+        n = tir.Var("n", "int64")
+        input_size = 784
+        hidden_sizes = [128, 32]
+        output_size = 10
+        bb = rx.BlockBuilder()
+
+        with bb.function("main"):
+            model = nn.Sequential(
+                nn.Linear(input_size, hidden_sizes[0]),
+                nn.ReLU(),
+                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
+                nn.ReLU(),
+                nn.Linear(hidden_sizes[1], output_size),
+                nn.LogSoftmax(),
+            )
+            data = nn.Placeholder((n, input_size), name="data")
+            output = model(data)
+            params = [data] + model.parameters()
+            bb.emit_func_output(output, params=params)
+        mod = bb.get()
+    """
+
+    _current = None
+
+    @staticmethod
+    def current():
+        """Returns the current BlockBuilder."""
+        return BlockBuilder._current
+
+    def __init__(self, mod: IRModule = None):
+        self._blocks = []
+        # a boolean flag that tracks if emit_func_output has been called
+        self._is_emit_func_output_called = False
+        self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod)
+
+    def _begin_dataflow_block(self) -> None:
+        _ffi_api.BlockBuilderBeginDataflowBlock(self)
+
+    def _begin_binding_block(self) -> None:
+        _ffi_api.BlockBuilderBeginBindingBlock(self)
+
+    def _end_block(self) -> BindingBlock:
+        return _ffi_api.BlockBuilderEndBlock(self)
+
+    def _enter_function_scope(self, name, params, attrs):
+        if BlockBuilder.current() is not None:
+            raise RuntimeError("BlockBuilder does not allow nested functions.")
+        BlockBuilder._current = self
+        self._func_name = name
+        self._func_params = params
+        self._func_attrs = attrs
+        self._begin_binding_block()
+
+    def _exit_function_scope(self, exc_type, exc_val, exc_tb):
+        # record
+        is_emit_func_output_called = self._is_emit_func_output_called
+        # recover to default state
+        self._blocks = []
+        self._is_emit_func_output_called = False
+        BlockBuilder._current = None
+
+        # NOTE: we must raise after we recover the state so that future
+        # block builder scoping works correctly
+        if exc_type is None:
+            if not is_emit_func_output_called:
+                raise RuntimeError("emit_func_output must be called in a relax function.")
+
+    def _convert_te_arg(self, te_args: Any) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
+        """Helper function to convert Relax expressions to te tensors.
+        In the common case, te_args is a single Relax expression and is converted
+        into a te tensor.
+        If te_args is a nested or recursive datatype (i.e., list, dict, tvm.ir.Map,
+        tvm.ir.Array), we recurse and convert any value of type Relax expression
+        into a te tensor. Common values of type int, float, and str are preserved.
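For instance, the following hedged sketch shows the kinds of mixed arguments this conversion is meant to admit; the names (`te_scale`, `x`) are illustrative only:

.. code-block:: python

    from tvm import relax, te, tir

    bb = relax.BlockBuilder()
    n = tir.Var("n", "int64")
    x = relax.Var("x", [n], relax.DynTensorType(1, "float32"))

    def te_scale(args, args_dict, factor):
        A = args[0]          # a relax Var converted to a te.Tensor
        B = args_dict["B"]   # dict values are converted recursively
        return te.compute(A.shape, lambda i: (A[i] + B[i]) * factor)

    with bb.function("main", [x]):
        # list/dict containers of relax exprs, plus a plain float passthrough
        out = bb.emit_te(te_scale, [x], {"B": x}, 2.0)
        bb.emit_func_output(out)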
+ + Parameters + ---------- + te_args : Any + Argument to convert to te + + Returns + ------- + ret : (Any, [tvm.te.Tensor]) + A tuple of the converted te_args, and a list of te tensors for each converted + Relax expression + """ + te_args_list = [] + + def _convert_te_arg_helper(arg): + if isinstance(arg, Expr): + arg = te_tensor(arg) + te_args_list.append(arg) + return arg + elif isinstance(arg, (list, tvm.ir.Array)): + return [_convert_te_arg_helper(x) for x in arg] + elif isinstance(arg, tuple): + return tuple([_convert_te_arg_helper(x) for x in arg]) + elif isinstance(arg, (dict, tvm.ir.Map)): + for key in arg: + assert isinstance( + key, str + ), "emit_te only supports dict with string as the key currently" + return {k: _convert_te_arg_helper(arg[k]) for k in arg} + elif ( + isinstance(arg, (int, float, str, tir.IntImm, tvm.ir.Type, tvm.ir.Attrs)) + or arg is None + ): + return arg + raise TypeError("not supported type in emit_te: {}".format(type(arg))) + + new_arg = _convert_te_arg_helper(te_args) + return new_arg, te_args_list + + def _get_unbound_tir_vars(self, args: List[tvm.te.Tensor]) -> List[tvm.tir.Var]: + """get unbound TIR vars (i.e TIR vars used in the shape but is not + itself a dimension of a shape)""" + bound_vars = set() + used_vars = set() + + def _populate_used_vars(expr): + if isinstance(expr, tvm.tir.Var): + used_vars.add(expr) + + for x in args: + for s in x.shape: + tvm.tir.stmt_functor.post_order_visit(s, _populate_used_vars) + if isinstance(s, tir.Var): + bound_vars.add(s) + + diff = used_vars - bound_vars + return list(diff) + + def function( + self, + name: str, + params: Optional[Union[Var, Tuple, List[Var]]] = None, + attrs: Optional[Dict[str, Object]] = None, + ) -> FunctionScope: + """Annotate a Relax function. + + Parameters + ---------- + name : str, optional + The name of the function + + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function. + If params is None, it means deferring initialization of function parameters + until emit_func_output. + + attrs : Dict[str, Object], optional + The function attrs + + Returns + ------- + ret: FunctionScope + A FunctionScope for building a Relax function node. + """ + if not params: + params = None + elif isinstance(params, rx.Var): + params = [params] + elif isinstance(params, (list, tuple)): + for param in params: + if not isinstance(param, rx.Var): + raise TypeError( + "each element of function parameters must be of type tvm.relax.Var,\ + but got: {}".format( + type(param) + ) + ) + if attrs is None: + attrs = {} + return FunctionScope(self, name, params, attrs) + + def dataflow(self) -> DataflowScope: + """Annotate a Relax dataflow block. + + Returns + ------- + ret: DataflowScope + A DataflowScope for building a Relax dataflow block. + """ + return DataflowScope(self) + + def emit(self, expr: Expr) -> Var: + """Emit an expr. + This infers the shape and type of the expr, create a variable, + and bind the expr to the variable. + + Parameters + ---------- + expr : tvm.relax.Expr + The Expr to be emitted. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the input expr. + """ + return _ffi_api.BlockBuilderEmit(self, expr) + + def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: + """Generate a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. 
+ Please see detailed example in emit_te + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + ret : tvm.relax.Call + A newly created call node + """ + + primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) + new_args, te_arg_list = self._convert_te_arg(args) + new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs) + + te_args = te_arg_list + te_kwarg_list + + te_out = func(*new_args, **new_kwargs) + assert isinstance(te_out, tvm.te.tensor.Tensor) or ( + isinstance(te_out, (tuple, list, tvm.ir.Array)) + and all(isinstance(t, tvm.te.tensor.Tensor) for t in te_out) + ), "only support te.tensor or tuple/list/Array of te.tensor as function output" + + if isinstance(te_out, (tuple, list, tvm.ir.Array)) and len(te_out) == 1: + te_out = te_out[0] + + outs = [te_out] if isinstance(te_out, tvm.te.tensor.Tensor) else list(te_out) + unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs) + + inputs = [*te_args] + outs + tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars) + + if primfunc_name_hint: + gvar = self.add_func(tir_func, primfunc_name_hint) + else: + gvar = self.add_func(tir_func, func.__name__) + + call_args = [x.op.value for x in te_args] + + output_shape = ( + outs[0].shape + if isinstance(te_out, tvm.te.tensor.Tensor) + else Tuple([ShapeExpr(x.shape) for x in outs]) + ) + + output_dtype = ( + te_out.dtype if isinstance(te_out, tvm.te.tensor.Tensor) else [x.dtype for x in outs] + ) + + # add arguments for extra parameters from unbound var + if len(unbound_tir_vars) > 0: + call = call_tir( + gvar, call_args, output_shape, output_dtype, tir_vars=ShapeExpr(unbound_tir_vars) + ) + else: + call = call_tir(gvar, call_args, output_shape, output_dtype) + return call + + def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: + """Emit a call node according to the te function. + This function converts arguments from relax expression to te tensor, + The callback func should return a te tensor or a list of te tensors. + + Parameters + ---------- + func : Callable + A function that returns a te tensor or a list of te tensors. + + args : Any, optional + arguments passed to the function. + + kwargs : Any, optional + The keyword arguments passed to the function. + Note that the key "primfunc_name_hint" is reserved for passing name hint + to the PrimFunc that gets generated. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the call code. + + Example + ------- + + .. code-block:: python + + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", [n, m], type_anno) + y = rx.Var("y", [n, m], type_anno) + + def te_func(args, args_dict, msg): + A = args[0] + B = args_dict["B"] + return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + + with bb.function([x, y], "rx_func"): + out = bb.emit_te(te_func, [x], {"B": y}, msg="hello") + bb.emit_func_output(out) + + will result in TVMScript + + .. 
code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, + var_compute: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "te_func"}) + m = T.var("int64") + n = T.var("int64") + rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32") + compute = T.match_buffer(var_compute, [128, 128], dtype="float32") + # body + # with T.block("root") + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]]) + T.writes([compute[i, j]]) + compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j] + + @R.function + def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tensor: + # block 0 + gv = relax.call_tir("te_func", (x, y), (128, 128), dtype="float32") + return gv + + Example + ------- + + .. code-block:: python + + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + type_anno = relax.DynTensorType(1, "float32") + x = relax.Var("x", [n], type_anno) + y = relax.Var("y", [n + 1], type_anno) + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1) + + will result in TVMScript + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @T.prim_func + def te_func(var_rxplaceholder: T.handle, var_compute: T.handle, n: T.int64) -> None: + # function attr dict + T.func_attr({"global_symbol": "te_func"}) + rxplaceholder = T.match_buffer(var_rxplaceholder, [n + T.int64(1)], + dtype="float32") + compute = T.match_buffer(var_compute, [n + T.int64(1)], dtype="float32") + # body + # with T.block("root") + for i0 in T.serial(0, n + T.int64(1)): + with T.block("compute"): + i = T.axis.spatial(n + T.int64(1), i0) + T.reads([rxplaceholder[i]]) + T.writes([compute[i]]) + compute[i] = rxplaceholder[i] + + @R.function + def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) + -> Tensor(None, "float32", ndim=-1): + # block 0 + gv = relax.call_tir(te_func, (y,), ((n + 1),), (n,), dtype="float32") + return gv + """ + return self.emit(self.call_te(func, *args, **kwargs)) + + def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var: + """Emit a MatchShape. + + Parameters + ---------- + value : tvm.relax.Expr + The value of the MatchShape to be emitted. + + pattern : List[PrimExpr] + The pattern of the MatchShape to be emitted. + + Returns + ------- + ret : tvm.relax.Var + A newly created variable that gets bound to the call code. + """ + return _ffi_api.BlockBuilderEmitMatchShape(self, value, pattern) + + def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: + """Emit output for the current dataflow block or function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if isinstance(output, (list, tuple)): + output = Tuple(output) + return _ffi_api.BlockBuilderEmitOutput(self, output) + + def emit_func_output( + self, + output: Union[Expr, Tuple, List[Expr]], + params: Optional[Union[Var, Tuple, List[Var]]] = None, + ) -> None: + """Emit output for the function. + + Parameters + ---------- + output : Expr | Tuple | List[Expr] + The output of the current block/function. 
+ + params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional + The parameters of the function to be built. + If params is None, it means the params have been initialized in the function with scope. + + Returns + ------- + ret : tvm.relax.Var + The return variable which gets bound to the output. + """ + if self._is_emit_func_output_called: + raise RuntimeError("emit_func_output must be called exactly once in a relax function.") + self._is_emit_func_output_called = True + + if self._func_params is not None and params is not None: + raise RuntimeError( + "function parameters have been initialized in the function with scope." + ) + + if self._func_params is None and params is None: + raise RuntimeError("Relax function must have parameter.") + + if self._func_params is None: + self._func_params = params + + if BlockBuilder.current() is not self: + raise RuntimeError("BlockBuilder._current must be self.") + + if isinstance(output, (list, tuple)): + output = Tuple(output) + self._func_ret = self.normalize(output) + + block = self._end_block() + if len(block.bindings) > 0: + self._blocks.append(block) + seqe = self.normalize(rx.SeqExpr(self._blocks, self._func_ret)) + + # The function's checked_type_ relies on the function body(seqe) to have deduced type + # TODO(@yuchen): handle the case where the body's checked_type_ is null + # TODO: Deduce the ret shape too + func = rx.Function(self._func_params, seqe, None, rx.RuntimeDepShape()) + func = func.with_attr("global_symbol", self._func_name) + for key, value in self._func_attrs.items(): + func = func.with_attr(key, value) + self.add_func(func, self._func_name) + + def normalize(self, expr: Expr) -> Expr: + """Normalize an Expr to complete its shape and type. + + Parameters + ---------- + expr : Expr + The input expr. + + Returns + ------- + ret : Expr + The expr with normalized shape and type. + """ + return _ffi_api.BlockBuilderNormalize(self, expr) + + def get(self) -> tvm.IRModule: + """Return the IRModule being built. + + Returns + ------- + ret : tvm.IRModule + An IRModule with Relax and TIR functions being built. + """ + return _ffi_api.BlockBuilderGetContextIRModule(self) + + def get_unique_name(self, name_prefix: str) -> str: + """Generate a unique name with a specified prefix. + + Parameters + ---------- + name_hint : str + The name prefix. + + Returns + ------- + ret : str + The generated name. + """ + return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix) + + def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + func : BaseFunc + The function to be added. + + func_name : str + The name of the function to be added. + + Returns + ------- + gvar : GlobalVar + The global var bound to the added function. + """ + return _ffi_api.BlockBuilderAddFunction(self, func, func_name) + + def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None: + """Add a Relax function or a TIR PrimFunc to the IRModule being built. + + Parameters + ---------- + gv : GlobalVar + The global var referring the function to be updated. + + updated_func : BaseFunc + The updated function. + """ + return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func) + + def can_prove_shape_equal(self, lhs: Expr, rhs: Expr) -> bool: + """Check if two shape expressions can be proven equal at compile time. + + Parameters + ---------- + lhs : Expr + The input lhs shape. + + rhs: Expr + The input rhs shape. 
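A tiny hedged example of the shape-proving hook above; whether a given proof succeeds ultimately depends on the underlying arithmetic analyzer:

.. code-block:: python

    from tvm import relax, tir

    bb = relax.BlockBuilder()
    n = tir.Var("n", "int64")
    lhs = relax.ShapeExpr([n + 1])
    rhs = relax.ShapeExpr([1 + n])
    # expected to be provable by simple canonicalization
    print(bb.can_prove_shape_equal(lhs, rhs))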
+ + Returns + ------- + ret : bool + Whether we can prove lhs shape is the same as the rhs shape. + """ + return _ffi_api.BlockBuilderCanProveShapeEqual(self, lhs, rhs) + + def current_block_is_dataflow(self) -> bool: + """Check if the block being built is DataflowBlock or not. + + Returns + ------- + ret : bool + A boolean that indicates if the block being built is DataflowBlock or not. + """ + return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self) + + def emit_var_binding(self, binding: VarBinding) -> Var: + """Emits a variable binding, and returns the bound Var. + + Parameters + ---------- + binding: VarBinding + The variable binding. + + Returns + ------- + var: Var + The bound variable. + """ + return _ffi_api.BlockBuilderEmitVarBinding(self, binding) + + def emit_output_var_binding(self, binding: VarBinding) -> Var: + """Generate an output for the current dataflow block. + + Parameters + ---------- + binding: VarBinding + The output binding to output. + + Returns + ------- + var: Var + The variable bound to output. + """ + return _ffi_api.BlockBuilderEmitOutputVarBinding(self, binding) + + def match_shape_binding(self, binding: MatchShape) -> Var: + """Emit a MatchShape binding. + + Parameters + ---------- + binding: MatchShape + The MatchShape binding to be emitted. + + Returns + ------- + var: Var + The variable bound to the MatchShape. + """ + return _ffi_api.BlockBuilderEmitMatchShapeBinding(self, binding) + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Lookup a var in the binding table binding_table_. + + Parameters + ---------- + var: Var + The input var. + + Returns + ------- + expr: Expr + The Expr bound to the input var. + """ + return _ffi_api.BlockBuilderLookupBinding(self, var) diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py new file mode 100644 index 0000000000..e0bbdaff05 --- /dev/null +++ b/python/tvm/relax/dpl/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Relax Dataflow Pattern Language.""" + +from .pattern import * +from .context import * diff --git a/python/tvm/relax/dpl/_ffi.py b/python/tvm/relax/dpl/_ffi.py new file mode 100644 index 0000000000..6699e42bee --- /dev/null +++ b/python/tvm/relax/dpl/_ffi.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DataFlow Pattern Language FFI bindings.""" +import tvm._ffi + +tvm._ffi._init_api("relax.dpl", __name__) diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py new file mode 100644 index 0000000000..222d0ce2d0 --- /dev/null +++ b/python/tvm/relax/dpl/context.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Graph Matching Context Manager for Dataflow Pattern Language.""" + +from typing import Optional, Dict + +import tvm +from tvm.relax import DataflowBlock, Var +from .pattern import DFPattern +from . import _ffi as ffi + + +class PatternContext(tvm.runtime.Object): + """A context object for doing graph (topogical) pattern matching.""" + + def __init__(self, incremental=False): + """ + Initialize the PatternContext + + Parameters + ---------- + incremental : bool, optional + perform incremental matching based on the recent context, by default False + """ + self.__init_handle_by_constructor__(ffi.PatternContext, incremental) + + def __enter__(self): + """Enter the context""" + ffi.enter_context(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the context""" + ffi.exit_context(self) + + @staticmethod + def current() -> "PatternContext": + """ + Get the current context + + Returns + ------- + PatternContext + The current context + """ + return ffi.current_context() + + def match_dfb( + self, + dfb: DataflowBlock, + start_hint: Optional[Var] = None, + must_include_hint: bool = False, + ) -> Dict[DFPattern, Var]: + """ + Match a DataflowBlock via a graph of DFPattern and corresponding constraints + + Parameters + ---------- + dfb : DataflowBlock + The DataflowBlock to match + start_hint : Optional[Var], optional + Indicating the starting expression to match, by default None + must_include_hint : bool, optional + Whether the start_hint expression must be matched, by default False + + Returns + ------- + Dict[DFPattern, Var] + The mapping from DFPattern to matched expression + """ + return ffi.match_dfb(self, dfb, start_hint, must_include_hint) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py new file mode 100644 index 0000000000..913d148115 --- /dev/null +++ b/python/tvm/relax/dpl/pattern.py @@ -0,0 +1,1038 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern types in Relax Dataflow Pattern Language""" +# pylint: disable=no-member +# pylint: disable=pointless-statement + +from typing import List, Optional, Dict, Union, Tuple + +import tvm +import tvm._ffi as tvm_ffi +from tvm.ir.expr import PrimExpr +from tvm.relax import Expr, Var +from tvm.relay.op import get +from tvm.ir.container import Array + +from ...ir import make_node +from ...runtime import Object +from ...ir.base import Node +from . import _ffi as ffi + + +def register_df_node(type_key=None): + """ + Register a Relax node type + + Parameters + ---------- + type_key : str or cls + The type key of the node + """ + if not isinstance(type_key, str): + return tvm_ffi.register_object("relax.dpl." + type_key.__name__)(type_key) + return tvm_ffi.register_object(type_key) + + +class DFPattern(Node): + """Base class of all Patterns.""" + + def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern": + """ + Syntax sugar for creating a CallPattern with argument patterns + + Returns + ------- + result: CallPattern + The resulting CallPattern + """ + return CallPattern(self, args, varg_default_wildcard) + + def __or__(self, other: "DFPattern") -> "OrPattern": + """ + Syntax sugar for creating an OrPattern + + Parameters + ---------- + other: DFPattern + Alternative pattern + + Returns + ------- + result: OrPattern + The resulting OrPattern + """ + return OrPattern(self, other) + + def __and__(self, other: "DFPattern") -> "AndPattern": + """ + Syntax sugar for creating an AndPattern + + Parameters + ---------- + other: DFPattern + Additional pattern to satisfy + + Returns + ------- + result: AndPattern + The resulting AndPattern + """ + return AndPattern(self, other) + + def __invert__(self) -> "NotPattern": + """ + Syntax sugar for creating a DFPattern to reject + + Returns + ------- + result: NotPattern + The resulting NotPattern + """ + return reject(self) + + def has_attr(self, attrs: Dict[str, Object]) -> "AttrPattern": + """ + Add an attribute constraint to this pattern + + Parameters + ---------- + attrs: Dict[str, Object] + + Returns + ------- + result: AttrPattern + The resulting AttrPattern + """ + attrs = make_node("DictAttrs", **attrs) + return AttrPattern(self, attrs) + + def has_type(self, ttype: tvm.ir.type.Type) -> "TypePattern": + """ + Add a type constraint to this pattern + + Parameters + ---------- + ttype: tvm.ir.type.Type + The type to match + + Returns + ------- + result: TypePattern + The resulting TypePattern + """ + return TypePattern(self, ttype) + + def has_dtype(self, dtype: str) -> "DataTypePattern": + """ + Add a type constraint to this pattern + + Parameters + ---------- + dtype: str + The dtype to match + + Returns + ------- + result: DataTypePattern + The resulting DataTypePattern + """ + return has_dtype(dtype, self) + + def has_shape(self, shape: 
List[PrimExpr]) -> "ShapePattern": + """ + Add a shape constraint to this pattern + + Parameters + ---------- + shape: List[PrimExpr] + Expected shape list + + Returns + ------- + result: ShapePattern + The resulting ShapePattern + + Note + ---- + has_shape assumes that the matched relax.Expr only has one + output tensor. Use is_tuple for those with multiple outputs. + """ + if not isinstance(shape, (list, tuple, tvm.ir.PrimExpr)): + raise ValueError("has_shape takes a list or tuple as input.") + return ShapePattern(pattern=self, shape=shape) + + def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool: + """ + Match a relax.Expr syntactically + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match + var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]] + A mapping from relax.Var to relax.Expr for autojump. + + Returns + ------- + result: bool + Whether or not the expression matches the pattern + + Note + ---- + Unlike Relay whose function is an expression, functions in Relax consists + of blocks of bindings that they are not syntactically connected. We use a + mapping (i.e., var2val) to migrate the gap. For example, to when matching + "relax.add(lv0, lv1)", given var2val, we match lv0's binded expression + when the recursive pattern matching goes to check lv0. The var2val mapping + can be computed through the tvm.relax.analysis.get_var2val function. + """ + return ffi.match_expr(self, expr, var2val) + + def has_rt_dep_shape(self) -> "AndPattern": + """ + Syntax sugar for assuming current node has a runtime-dependent shape + + Returns + ------- + result: AndPattern + The resulting AndPattern + """ + return RuntimeDepShapePattern(self) + + def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": + """ + The current pattern being used by another pattern (sequence) + + Parameters + ---------- + other : Union[DFPattern, DFPattern] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + result: PatternSeq + A chained pattern sequence + """ + return _used_by(self, other, index) + + def __xor__(self, other: Union["DFPattern", "PatternSeq"]) -> "PatternSeq": + """Syntax sugar of DFPattern.used_by""" + return self.used_by(other, -1) + + def only_used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq": + """ + The current pattern being **ONLY** used by another pattern (sequence) + + Parameters + ---------- + other : Union[DFPattern, DFPattern] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + result: PatternSeq + A chained pattern sequence + """ + return _only_used_by(self, other, index) + + def __rshift__(self, other: Union["DFPattern", "PatternSeq"]) -> "PatternSeq": + """Syntax sugar of DFPattern.only_used_by""" + return self.only_used_by(other, -1) + + def dup(self) -> "DFPattern": + """ + Duplicate the current pattern (new object under different address) + + Returns + ------- + DFPattern + A duplicated pattern + """ + return ffi.dup_pattern(self) + + def fork_to(self, *args) -> None: + """Fork the current pattern to multiple pattern branches""" + for v in args: + self ^ v + + +@register_df_node +class RuntimeDepShapePattern(DFPattern): + """A pattern matching a Relax RuntimeDepShape.""" + + def __init__(self, pattern: DFPattern): + self.__init_handle_by_constructor__(ffi.RuntimeDepShapePattern, pattern) + + 
+@register_df_node +class ExprPattern(DFPattern): + """A pattern which matches an expression. + + Parameters + ---------- + expr : tvm.relax.Expr + The expression to match. + """ + + def __init__(self, expr: Expr): + self.__init_handle_by_constructor__(ffi.ExprPattern, expr) + + +@register_df_node +class VarPattern(DFPattern): + """A pattern for Var. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.VarPattern, name_hint) + + +@register_df_node +class DataflowVarPattern(DFPattern): + """A pattern for DataflowVar. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.DataflowVarPattern, name_hint) + + +@register_df_node +class GlobalVarPattern(DFPattern): + """A pattern for GlobalVar. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any GlobalVarNode. + """ + + def __init__(self, name_hint: str = ""): + self.__init_handle_by_constructor__(ffi.GlobalVarPattern, name_hint) + + +@register_df_node +class ExternFuncPattern(DFPattern): + """A external function pattern. + + Parameters + ---------- + global_symbol: str + The name of the function. Optional, if not provided, + the pattern will match any ExternFuncNode. + """ + + def __init__(self, global_symbol: str = ""): + self.__init_handle_by_constructor__(ffi.ExternFuncPattern, global_symbol) + + +@register_df_node +class ConstantPattern(DFPattern): + """A pattern matching a Relax Constant.""" + + def __init__(self): + self.__init_handle_by_constructor__(ffi.ConstantPattern) + + +@register_df_node +class CallPattern(DFPattern): + """A pattern matching a function call node. + + Parameters + ---------- + op: tvm.relax.dpl.DFPattern + The operation to be called. + + args: List[tvm.relax.dpl.DFPattern] + The arguments to the call or None to match any arguments. + + varg_default_wildcard: bool + If True, args can be fewer than actual provided arguments. + + Note + ---- + By setting varg_default_wildcard to True, we can only focus on the argument + patterns we specified. For example, CallPattern(Op, [A, B]) can match + a call of Op(A, B) or Op(A, B, C, ...) that has more arguments. However, + the specified argument patterns must be matched (i.e., A and B). + """ + + def __init__( + self, + op: "DFPattern", + args: List["DFPattern"], + varg_default_wildcard: bool = False, + ): + self.__init_handle_by_constructor__(ffi.CallPattern, op, args, varg_default_wildcard) + + +@register_df_node +class FunctionPattern(DFPattern): + """A pattern matching a function node in Relax. + + Parameters + ---------- + params: List[tvm.relax.dpl.DFPattern] + The parameters to the Function or None to match any parameters. + + body: tvm.relax.dpl.DFPattern + The body fo the Function + + """ + + def __init__( + self, + params: List["DFPattern"], + body: "DFPattern", + ): + self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body) + + +@register_df_node +class TuplePattern(DFPattern): + """A patern matching a Relax Tuple. + + Parameters + ---------- + fields : Array[tvm.relax.dpl.DFPattern] + The fields in the tuple. 
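A hedged sketch of the variadic behavior described in the CallPattern note above: with varg_default_wildcard=True only the leading arguments need to be spelled out.

.. code-block:: python

    from tvm.relax.dpl import is_op, wildcard

    # matches relax.call_tir(gv, args, shape, ...) no matter how many
    # trailing arguments the call actually carries
    pat = is_op("relax.call_tir")(wildcard(), varg_default_wildcard=True)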
+ """ + + def __init__(self, fields: Array): + self.__init_handle_by_constructor__(ffi.TuplePattern, fields) + + def __getitem__(self, index: Optional[int]) -> "TupleGetItemPattern": + if index is not None: + # support negative index for being pythonic + if index < 0: + index += len(self) + if index >= len(self): + raise IndexError("TuplePattern index out of range") + else: + index = -1 # -1 means matching any index + return TupleGetItemPattern(self, index) + + def __len__(self): + return len(self.fields) + + +@register_df_node +class UnorderedTuplePattern(DFPattern): + """A patern matching a Relax Tuple unorderedly. + + Parameters + ---------- + fields : Array[tvm.relax.dpl.DFPattern] + The fields in the tuple. + """ + + def __init__(self, fields: Array): + self.__init_handle_by_constructor__(ffi.UnorderedTuplePattern, fields) + + def __len__(self): + return len(self.fields) + + +@register_df_node +class TupleGetItemPattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + tuple_value: tvm.relax.dpl.DFPattern + The input tuple expression. + + index: Optional[int] + The index to match; Default (None) to match a TupleGetItem with any index. + """ + + def __init__(self, tuple_value: "DFPattern", index: Optional[int] = None): + match_index = index if index is not None else -1 + self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, match_index) + + +@register_df_node +class OrPattern(DFPattern): + """Create a Pattern that can match one of two conditions + + Parameters + ---------- + left: tvm.relax.dpl.DFPattern + One possible matching pattern. + right: tvm.relax.dpl.DFPattern + One possible matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.OrPattern, left, right) + + +@register_df_node +class AndPattern(DFPattern): + """Create a Pattern that must match two conditions + + Parameters + ---------- + left: tvm.relax.dpl.DFPattern + One must-matching pattern. + right: tvm.relax.dpl.DFPattern + One must-matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.AndPattern, left, right) + + +@register_df_node +class NotPattern(DFPattern): + """Create a Pattern that matches the negation of a condition. + + Parameters + ---------- + to_reject: tvm.relax.dpl.DFPattern + The pattern to deny. + """ + + def __init__(self, to_reject: "DFPattern"): + self.__init_handle_by_constructor__(ffi.NotPattern, to_reject) + + +@register_df_node +class WildcardPattern(DFPattern): + """A pattern which matches anything.""" + + def __init__(self): + self.__init_handle_by_constructor__(ffi.WildcardPattern) + + +@register_df_node +class TypePattern(DFPattern): + """A pattern that matches another pattern with a certain type annotation. + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + ttype: tvm.ir.type.Type + The type to match. + """ + + def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): + self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) + + +@register_df_node +class DataTypePattern(DFPattern): + """A pattern that matches another pattern with certain data type + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + dtype: str + The dtype to match. 
+ """ + + def __init__(self, pattern: "DFPattern", dtype: str): + self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype) + + +@register_df_node +class ShapePattern(DFPattern): + """A pattern that matches another pattern with a certain tensor shape + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern that needs type annotation. + + shape: List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) + + +@register_df_node +class PrimArrPattern(DFPattern): + """ + A pattern to match an array of PrimExpr + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.PrimArrPattern, shape) + + def __getitem__(self, index: int): + if index >= len(self): + raise IndexError("PrimArrPattern index out of range") + return self.fields[index] + + def __len__(self): + return len(self.fields) + + +@register_df_node +class AttrPattern(DFPattern): + """Get match an expression with a certain attributes. + Currently only supports Op Attributes, not call Attributes. + + Parameters + ---------- + pattern: tvm.relax.dpl.DFPattern + The input pattern. + + attrs: tvm.ir.attrs.Attrs + The attributes to match. + """ + + def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs): + self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs) + + +def is_var(name: str = "") -> VarPattern: + """ + Syntatic sugar for creating an optionally named VarPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relax.dpl.VarPattern + The resulting pattern. + """ + return VarPattern(name) + + +def is_gv(name: str = "") -> GlobalVarPattern: + """Syntax sugar for creating an optionally (if name is empty) named GlobalVarPattern.""" + return GlobalVarPattern(name) + + +def is_dfv(name: str = "") -> DataflowVarPattern: + """Syntax sugar for creating an optionally (if name is empty) named DataflowVarPattern.""" + return DataflowVarPattern(name) + + +def is_const() -> ConstantPattern: + """ + Syntatic sugar for creating a ConstantPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relax.dpl.ConstantPattern + The resulting pattern. + """ + return ConstantPattern() + + +def is_expr(expr: Expr) -> ExprPattern: + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + expr: Expr + The Relax expression to match. + + Returns + ------- + result: tvm.relax.dpl.ExprPattern + The resulting pattern. + """ + return ExprPattern(expr) + + +def is_op(op_name: str) -> ExprPattern: + """ + Syntatic sugar for creating an operator ExprPattern. + + Parameters + ---------- + op_name: String + The name of the tvm.ir.op.Op object + + Returns + ------- + result: tvm.relax.dpl.ExprPattern + The resulting ExprPattern + """ + op = get(op_name) + return ExprPattern(op) + + +def is_tuple( + fields: Union[Array, List, Tuple], unordered=False +) -> Union[TuplePattern, UnorderedTuplePattern]: + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + fields : Array[tvm.relax.dpl.DFPattern] + The fields in the tuple. + + Returns + ------- + result: tvm.relax.dpl.DFPattern + The resulting pattern. 
+ """ + if not isinstance(fields, (list, tuple, Array)): + raise ValueError("fields must be a list, tuple, or Array") + if unordered: + return UnorderedTuplePattern(fields) + return TuplePattern(fields) + + +def is_tuple_get_item(tuple_value: DFPattern, index: Optional[int] = None) -> TupleGetItemPattern: + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + tuple_value: tvm.relax.dpl.DFPattern + The input tuple expression. + + index: Optional[int] + The index to match; Default (None) to match a TupleGetItem with any index. + + Returns + ------- + result: tvm.relax.dpl.TupleGetItemPattern + The resulting pattern. + """ + return TupleGetItemPattern(tuple_value, index) + + +def wildcard() -> WildcardPattern: + """ + Syntatic sugar for creating a WildcardPattern. + + Returns + ------- + result: tvm.relax.dpl.WildcardPattern + The resulting pattern. + """ + return WildcardPattern() + + +def has_dtype(dtype: str, pattern: DFPattern = None) -> DataTypePattern: + """ + Syntatic sugar for creating a DataTypePattern + + Parameters + ---------- + dtype: str + The dtype to match + + pattern: tvm.relax.dpl.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relax.dpl.DataTypePattern + The resulting DataTypePattern + """ + if pattern is None: + pattern = wildcard() + return DataTypePattern(pattern, dtype) + + +def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern": + """ + Directly matches a shape which is an array of PrimExpr + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The expected shape + + Returns + ------- + PrimArrPattern + The resulting PrimArrPattern pattern + + Raises + ------ + ValueError + If the argument shape is not a list/tuple/tvm.ir.Array + + Note + ---- + The difference between p.has_shape(s) and is_shape(s) is that: has_shape + puts assumptions on the shape of the tensor matched by pattern p. While + is_shape directly matches the shape (an array of PrimExpr). + """ + if not isinstance(shape, (list, tuple, tvm.ir.Array)): + raise ValueError("is_shape takes a list or tuple as input.") + return PrimArrPattern(shape) + + +def is_call_tir( + func_name: str, + args: Union[List, Tuple, TuplePattern] = None, + shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None, +) -> CallPattern: + """ + Syntax sugar for creating a CallPattern for call_tir + + Parameters + ---------- + func_name : str + Name of the CPS function to call. 
+    args : Union[List[DFPattern], Tuple[DFPattern]], optional
+        Arguments of the expected call_tir, by default None meaning arbitrary (number of) arguments
+    shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional
+        Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s)
+
+    Returns
+    -------
+    CallPattern
+        The resulting CallPattern
+    """
+    if args is None:
+        args = wildcard()
+    elif isinstance(args, (list, tuple)):
+        args = TuplePattern(args)
+
+    if shape is None:
+        shape = wildcard()
+    elif isinstance(shape, (list, Array)):
+        shape = PrimArrPattern(shape)
+    elif isinstance(shape, tuple):
+        shape = is_tuple(shape)  # multiple shape patterns
+
+    return is_op("relax.call_tir")(GlobalVarPattern(func_name), args, shape)
+
+
+def is_call_packed(
+    func_name: str, args: Union[List[DFPattern], Tuple[DFPattern]] = None
+) -> CallPattern:
+    """
+    Syntax sugar for creating a CallPattern for call_packed.
+
+    Parameters
+    ----------
+    func_name : str
+        Name of the external function to call.
+    args : Union[List[DFPattern], Tuple[DFPattern]], optional
+        Arguments of the expected call_packed, by default None meaning arbitrary (number of) arguments
+
+    Returns
+    -------
+    CallPattern
+        The resulting CallPattern
+    """
+    if args is None:
+        return ExternFuncPattern(func_name)(varg_default_wildcard=True)
+    return ExternFuncPattern(func_name)(*args)
+
+
+def reject(pattern: DFPattern) -> NotPattern:
+    """
+    Syntax sugar for creating a NotPattern to reject a pattern.
+
+    Parameters
+    ----------
+    pattern : DFPattern
+        The pattern to reject.
+
+    Returns
+    -------
+    result: NotPattern
+        The resulting NotPattern.
+    """
+    return NotPattern(pattern)
+
+
+def has_attr(attrs, pattern=None) -> AttrPattern:
+    """
+    Syntactic sugar for creating an AttrPattern.
+
+    Parameters
+    ----------
+    attrs: Dict[str, Object]
+        The attributes to match.
+
+    pattern: Optional[tvm.relax.dpl.DFPattern]
+        The input pattern.
+
+    Returns
+    -------
+    result: tvm.relax.dpl.DFPattern
+        The resulting AttrPattern.
+    """
+    if pattern is None:
+        pattern = wildcard()
+    return pattern.has_attr(attrs)
+
+
+@register_df_node
+class PatternSeq(Node):
+    """A sequence of patterns with consecutive constraints."""
+
+    def __init__(self, patterns: List[DFPattern], only_use=False):
+        """
+        Initializer of PatternSeq.
+
+        Parameters
+        ----------
+        patterns : List[DFPattern]
+            A chain of patterns.
+        only_use : bool, optional
+            Whether the patterns follow only-used-by relations consecutively, by default False.
+        """
+        self.__init_handle_by_constructor__(ffi.PatternSeq, patterns, only_use)
+
+    def used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq":
+        """
+        Require that the right-most pattern is used by the `other` pattern,
+        i.e., that it acts as a producer for `other`.
+
+        Parameters
+        ----------
+        other : Union[DFPattern, PatternSeq]
+            The consumer pattern (sequence).
+        index : int, optional
+            The argument index called by the consumer pattern, by default -1.
+
+        Returns
+        -------
+        PatternSeq
+            A chained pattern sequence.
+
+        Note
+        ----
+        If other is a PatternSeq, it means the right-most pattern must be used by the left-most
+        pattern of the other sequence.
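+
+        Example
+        -------
+        A short sketch using only helpers from this module:
+
+        .. code-block:: python
+
+            producer = PatternSeq([wildcard()])
+            consumer = wildcard()
+            # require the right-most pattern of `producer` to feed `consumer`
+            chained = producer.used_by(consumer)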
+ """ + return _used_by(self, other, index) + + def only_used_by(self, other: Union[DFPattern, "PatternSeq"], index=-1) -> "PatternSeq": + + """ + Assuming the right-most pattern must be **ONLY** used by the `other` pattern as a producer + + Parameters + ---------- + other : Union[DFPattern, PatternSeq] + The consumer pattern (sequence) + index : int, optional + The argument index called by the consumer pattern, by default -1 + + Returns + ------- + PatternSeq + A chained pattern sequence + + Note + ---- + If other is PatternSeq, it means the right-most pattern must be **ONLY** used by the + left-most pattern of the other sequence. + """ + return _only_used_by(self, other, index) + + def __getitem__(self, index: int) -> DFPattern: + """ + Access the pattern at the given index + + Parameters + ---------- + index : int + Index of the accessed pattern + + Returns + ------- + DFPattern + The accessed pattern + """ + return self.patterns[index] + + def __xor__(self, other) -> "PatternSeq": + """Syntax sugar of PatternSeq.used_by""" + return self.used_by(other, -1) + + def __rshift__(self, other) -> "PatternSeq": + """Syntax sugar of PatternSeq.only_used_by""" + return self.only_used_by(other, -1) + + def dup(self) -> "PatternSeq": + """ + Duplicate the pattern sequence (new object under different address) + + Returns + ------- + PatternSeq + A duplicated chain + """ + return ffi.dup_seq(self) + + +### Private functions + + +def _used_by( + lhs: Union[DFPattern, PatternSeq], + rhs: Union[DFPattern, PatternSeq], + index=-1, +) -> PatternSeq: + if isinstance(lhs, DFPattern): + lhs = PatternSeq([lhs]) + if isinstance(rhs, DFPattern): + rhs = PatternSeq([rhs]) + return ffi.used_by(lhs, rhs, index) + + +def _only_used_by( + lhs: Union[DFPattern, PatternSeq], rhs: Union[DFPattern, PatternSeq], index=-1 +) -> PatternSeq: + if isinstance(lhs, DFPattern): + lhs = PatternSeq([lhs]) + if isinstance(rhs, DFPattern): + rhs = PatternSeq([rhs]) + return ffi.only_used_by(lhs, rhs, index) diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py new file mode 100644 index 0000000000..4f3cabe9bf --- /dev/null +++ b/python/tvm/relax/exec_builder.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""A builder to build Relax VM executable.""" +from enum import IntEnum +from typing import Optional, Union, List +import tvm +from tvm._ffi._ctypes.packed_func import TVMRetValueHandle +from tvm.runtime import Object +from tvm.runtime.container import ShapeTuple +from .vm import Executable +from . 
import _ffi_api + + +class SpecialReg(IntEnum): + """Magic numbers that represent special registers in vm.""" + + VOID_ARG = 0x00EC66FE0321975A + VM_STATE = 0x008D14FA4379015C + + +class VMFuncScope(object): + """An object corresponds to each VM function, working as a context manager.""" + + stack = [] + + def __enter__(self): + VMFuncScope.stack.append(self) + return self + + def __exit__(self, ptype, value, trace): + VMFuncScope.stack.pop() + + +@tvm._ffi.register_object("relax.ExecBuilder") +class ExecBuilder(Object): + """A builder to emit instructions and build executable for the virtual machine.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) + + def r(self, idx: int) -> int: + """set instruction's argument as a register.""" + return _ffi_api.ExecBuilderR(self, idx) + + def imm(self, value: int) -> int: + """set instruction's argument as an immediate.""" + return _ffi_api.ExecBuilderImm(self, value) + + def c(self, idx: int) -> int: + """set instruction's argument as a constant.""" + return _ffi_api.ExecBuilderC(self, idx) + + def void_arg(self) -> int: + return self.r(SpecialReg.VOID_ARG) + + def vm_state(self) -> int: + return self.r(SpecialReg.VM_STATE) + + def function( + self, func_name: str, num_inputs: Optional[int] = 0, param_names: List[str] = None + ) -> VMFuncScope: + """annotate a VM function.""" + _ffi_api.ExecBuilderFunction(self, func_name, num_inputs, param_names) + return VMFuncScope() + + def _check_scope(self) -> None: + if len(VMFuncScope.stack) == 0: + raise ValueError("emit should happen in a function scope") + + def emit_constant(self, const: TVMRetValueHandle) -> int: + return _ffi_api.ExecBuilderEmitConstant(self, const) + + def emit_call( + self, + name: str, + args: Optional[List[Union[tvm.nd.NDArray, tvm.DataType]]] = None, + dst: int = None, + ) -> None: + """emit a call instruction which calls a packed function.""" + self._check_scope() + if dst is None: + dst = SpecialReg.VOID_ARG + args_ = [] + if args is not None: + for arg in args: + if isinstance(arg, tuple): + shape_tuple = ShapeTuple(arg) + new_arg = self.emit_constant(shape_tuple) + args_.append(new_arg) + elif isinstance(arg, (tvm.nd.NDArray, tvm.DataType, ShapeTuple)): + new_arg = self.emit_constant(arg) + args_.append(new_arg) + else: + args_.append(arg) + _ffi_api.ExecBuilderEmitCall(self, name, args_, dst) + + def emit_ret(self, result: int) -> None: + """emit a return instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitRet(self, result) + + def emit_goto(self, pc_offset): + """emit a goto instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitGoto(self, pc_offset) + + def emit_if(self, cond, false_offset): + """emit an if instruction""" + self._check_scope() + _ffi_api.ExecBuilderEmitIf(self, cond, false_offset) + + def get(self) -> Executable: + """return the executable""" + return Executable(_ffi_api.ExecBuilderGet(self)) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py new file mode 100644 index 0000000000..18b4921289 --- /dev/null +++ b/python/tvm/relax/expr.py @@ -0,0 +1,294 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import, super-init-not-called +# pylint: disable=redefined-builtin +"""The expression nodes of Relax.""" +from typing import Any, List, Optional, Union + +import tvm +import tvm._ffi + +from .. import relay +from ..ir import BaseFunc, Node, SourceName, Span +from ..relay import Id, Tuple, TupleGetItem +from ..runtime import String +from ..tir import PrimExpr +from . import _ffi_api, ty + +Expr = relay.Expr +Type = relay.Type +GlobalVar = relay.GlobalVar +Call = relay.Call +If = relay.If +const = relay.const +Constant = relay.Constant + + +@tvm._ffi.register_object("relax.expr.ShapeExpr") +class ShapeExpr(Expr): + """A shape expression which allows users to construct a shape containing PrimExpr.""" + + values: List[PrimExpr] + + def __init__(self, values: List[PrimExpr], span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) + + def __getitem__(self, index): + if index >= len(self): + raise IndexError("Tuple index out of range") + return self.values[index] + + def __len__(self): + return len(self.values) + + +def make_shape(shape: List[PrimExpr]) -> ShapeExpr: + if isinstance(shape, (list, tuple)): + return ShapeExpr(shape) + raise ValueError("Wrong type") + + +@tvm._ffi.register_object("relax.expr.RuntimeDepShape") +class RuntimeDepShape(Expr): + """A shape expression which allows users to construct a runtime dependent shape.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.RuntimeDepShape, span) + + +@tvm._ffi.register_object("relax.expr.Var") +class Var(Expr): + """The variable class for all Relax bindings.""" + + vid: Id + type_annotation: Optional[Type] + + def __init__( + self, + name_hint: str, + shape_annotation: Optional[Expr] = None, + type_annotation: Optional[Type] = None, + span: Span = None, + ) -> None: + if isinstance(shape_annotation, (list, tuple)): + shape_annotation = make_shape(shape_annotation) + self.__init_handle_by_constructor__( + _ffi_api.Var if isinstance(name_hint, str) else _ffi_api.VarFromId, + name_hint, + shape_annotation, + type_annotation, + span, + ) + + @property + def name_hint(self): + """Get name hint of the current var.""" + name = str(self.vid.name_hint) + return name + + def __call__(self, *args: Any, attrs=None) -> Call: + if self.checked_type and isinstance(self.checked_type, ty.FuncType): + return Call(self, args, attrs=attrs) + else: + raise TypeError("Only vars with function type can be called") + + +@tvm._ffi.register_object("relax.expr.DataflowVar") +class DataflowVar(Var): + """A sub-type of the variable node used to mark dataflow variables from + normal visible "function local" bindings.""" + + def __init__( + self, + name_hint: Union[str, Id], + shape_annotation: Optional[Expr] = None, + type_annotation: Optional[Type] = None, + span: Span = None, + ) -> None: + if isinstance(shape_annotation, (list, tuple)): + shape_annotation = make_shape(shape_annotation) + + self.__init_handle_by_constructor__( + _ffi_api.DataflowVar if isinstance(name_hint, str) else _ffi_api.DataflowVarFromId, + name_hint, 
+            shape_annotation,
+            type_annotation,
+            span,
+        )
+
+
+@tvm._ffi.register_object("relax.expr.Binding")
+class Binding(Node):
+    """The base class of a binding in Relax."""
+
+    ...
+
+
+@tvm._ffi.register_object("relax.expr.MatchShape")
+class MatchShape(Binding):
+    """Symbolic shape match, binds the variable of the lhs with the rhs."""
+
+    value: Expr
+    pattern: List[PrimExpr]
+    var: Var
+
+    def __init__(self, value: Expr, pattern: List[PrimExpr], var: Var, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.MatchShape, value, pattern, var, span)
+
+
+@tvm._ffi.register_object("relax.expr.VarBinding")
+class VarBinding(Binding):
+    """Variable binding, binds the variable of the lhs with the rhs."""
+
+    var: Var
+    value: Expr
+
+    def __init__(self, var: Var, value: Expr, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span)
+
+
+@tvm._ffi.register_object("relax.expr.BindingBlock")
+class BindingBlock(Node):
+    """The base class of binding blocks; bindings inside can be impure
+    (with side effects or control flow)."""
+
+    bindings: List[Binding]
+
+    def __init__(self, bindings: List[Binding], span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span)
+
+
+@tvm._ffi.register_object("relax.expr.DataflowBlock")
+class DataflowBlock(BindingBlock):
+    """A dataflow block; bindings inside are pure (no side effects and no control flow)."""
+
+    def __init__(self, bindings: List[Binding], span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span)
+
+
+@tvm._ffi.register_object("relax.expr.SeqExpr")
+class SeqExpr(Expr):
+    """A sequence of binding blocks followed by an expression."""
+
+    blocks: List[BindingBlock]
+    body: Expr
+
+    def __init__(self, blocks: List[BindingBlock], body: Expr, span: Span = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span)
+
+
+@tvm._ffi.register_object("relax.expr.Function")
+class Function(BaseFunc):
+    """A Relax function."""
+
+    params: List[Var]
+    body: Expr
+    ret_type: Type
+    ret_shape: Expr
+    attrs: Optional[tvm.ir.DictAttrs]
+
+    def __init__(
+        self,
+        params: List[Var],
+        body: Expr,
+        ret_type: Type,
+        ret_shape: Expr,
+        attrs: Optional[tvm.ir.DictAttrs] = None,
+        span: Optional[Span] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.Function, params, body, ret_type, ret_shape, attrs, span
+        )
+
+    @staticmethod
+    def create_unchecked(
+        params: List[Var],
+        body: Expr,
+        ret_type: Type,
+        ret_shape: Expr,
+        attrs: Optional[tvm.ir.DictAttrs] = None,
+        span: Optional[Span] = None,
+    ):
+        """Construct a relax.Function without type checking."""
+        return _ffi_api.Function_CreateUnchecked(params, body, ret_type, ret_shape, attrs, span)
+
+    def __call__(self, *args):
+        """Invoke the global function.
+
+        Parameters
+        ----------
+        args: List[relax.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
+    def script(self, show_meta: bool = False) -> str:
+        """Print relax.Function into TVMScript.
+
+        Parameters
+        ----------
+        show_meta : bool
+            Whether to show meta information.
+
+        Returns
+        -------
+        script : str
+            The TVM Script of the relax.Function.
+        """
+        return tvm._ffi.get_global_func("script.AsRelaxScript")(self, show_meta)  # type: ignore
+
+    def show(self, style: str = "light") -> None:
+        """
+        A sugar for printing highlighted TVM script.
+ + Parameters + ---------- + style : str, optional + Pygments styles extended by "light" (default) and "dark", by default "light" + """ + from tvm.script.highlight import cprint # pylint: disable=import-outside-toplevel + + # Use deferred import to avoid circular import while keeping cprint under tvm/script + cprint(self, style=style) + + +@tvm._ffi.register_object("relax.expr.ExternFunc") +class ExternFunc(BaseFunc): + """extern function, which can represent a TIR PrimFunc or a PackedFunc.""" + + global_symbol: String + + def __init__(self, global_symbol: String, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol, span) + + +def extern(name: str, span: Span = None): + """Create extern function.""" + return ExternFunc(name, span) + + +def te_tensor(value: Expr, name: str = "rxplaceholder"): + """Create te tensor from relax expression.""" + return _ffi_api.TETensor(value, name) + + +def _update_type(expr: Expr, type: Type) -> None: + _ffi_api.UpdateType(expr, type) + + +def _update_shape(expr: Expr, shape: Optional[tvm.runtime.Object]) -> None: + _ffi_api.UpdateShape(expr, shape) diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py new file mode 100644 index 0000000000..38129faa76 --- /dev/null +++ b/python/tvm/relax/expr_functor.py @@ -0,0 +1,1466 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ +"""The expression functor of Relax.""" +from typing import Optional, Callable + +import tvm +from tvm.runtime import Object +from tvm.ir import Op +from tvm.meta_schedule.utils import derived_object + +from .expr import Type, Span, Expr +from .expr import Function, ExternFunc +from .expr import Constant, Var, DataflowVar +from .expr import ShapeExpr, RuntimeDepShape +from .expr import GlobalVar, SeqExpr, Tuple +from .expr import Call, If, TupleGetItem +from .expr import Binding, MatchShape, VarBinding +from .expr import BindingBlock, DataflowBlock +from ..relay import Id +from ..ir.module import IRModule +from .block_builder import BlockBuilder +from . import _ffi_api + +visitor = derived_object +""" +A decorator to wrap user-customized PyExprVisitor as TVM object _PyExprVisitor. + +Parameters +---------- +visitor_cls : PyExprVisitor + The user-customized PyExprVisitor. + +Returns +------- +cls : _PyExprVisitor + The decorated TVM object _PyExprVisitor(ExprVisitor on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.visitor + class MyExprVisitor(PyExprVisitor): + # customize visit function + def visit_call_(self, op: Call) -> None: + # just for demo purposes + ... 
+ # myvisitor is now a special visitor that visit every Call with + # user-customized visit_call_ + myvisitor = MyExprVisitor() + # apply myvisitor to Expr/Binding/BindingBlock/VarDef + myvisitor.visit_expr(expr) + myvisitor.visit_binding(binding) + myvisitor.visit_binding_block(bindingblock) + myvisitor.visit_var_def(var) +""" + +mutator = derived_object +""" +A decorator to wrap user-customized PyExprMutator as TVM object _PyExprMutator. +Note: Cannot override visit function and post-order rewrite at the same time. + +Parameters +---------- +mutator_cls : PyExprMutator + The user-customized PyExprMutator. + +Returns +------- +cls : _PyExprMutator + The decorated TVM object _PyExprMutator(ExprMutator on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.mutator + class MyExprMutator(PyExprMutator): + # customize rewrite function + def visit_tuple_(self, op: Tuple) -> Expr: + # just for demo purposes + ... + + # mymutator is now a special mutator that rewrite every Tuple with + # user-customized visit_tuple_ + mymutator = MyExprMutator() + # apply mymutator to Expr/Binding/BindingBlock/VarDef + mymutator.visit_expr(expr) + mymutator.visit_binding(binding) + mymutator.visit_binding_block(bindingblock) + mymutator.visit_var_def(var) +""" + + +class ExprFunctor: + """ + An abstract visitor defined over Expr. + Defines the default dispatch over expressions, and + implements memoization. + """ + + def visit_expr(self, expr): + """Apply the visitor to an expression.""" + if isinstance(expr, Constant): + ret = self.visit_constant_(expr) + elif isinstance(expr, Tuple): + ret = self.visit_tuple_(expr) + elif isinstance(expr, DataflowVar): + ret = self.visit_dataflow_var_(expr) + elif isinstance(expr, Var): + ret = self.visit_var_(expr) + elif isinstance(expr, ShapeExpr): + ret = self.visit_shape_expr_(expr) + elif isinstance(expr, RuntimeDepShape): + ret = self.visit_runtime_dep_shape_(expr) + elif isinstance(expr, ExternFunc): + ret = self.visit_extern_func_(expr) + elif isinstance(expr, GlobalVar): + ret = self.visit_global_var_(expr) + elif isinstance(expr, Function): + ret = self.visit_function_(expr) + elif isinstance(expr, Call): + ret = self.visit_call_(expr) + elif isinstance(expr, SeqExpr): + ret = self.visit_seq_expr_(expr) + elif isinstance(expr, If): + ret = self.visit_if_(expr) + elif isinstance(expr, Op): + ret = self.visit_op_(expr) + elif isinstance(expr, TupleGetItem): + ret = self.visit_tuple_getitem_(expr) + else: + raise TypeError("Invalid type: {0}".format(type(expr))) + + return ret + + def visit_constant_(self, op: Constant): + raise NotImplementedError() + + def visit_tuple_(self, op: Tuple): + raise NotImplementedError() + + def visit_dataflow_var_(self, op: DataflowVar): + raise NotImplementedError() + + def visit_var_(self, op: Var): + raise NotImplementedError() + + def visit_shape_expr_(self, op: ShapeExpr): + raise NotImplementedError() + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape): + raise NotImplementedError() + + def visit_extern_func_(self, op: ExternFunc): + raise NotImplementedError() + + def visit_global_var_(self, op: GlobalVar): + raise NotImplementedError() + + def visit_function_(self, op: Function): + raise NotImplementedError() + + def visit_call_(self, op: Call): + raise NotImplementedError() + + def visit_seq_expr_(self, op: SeqExpr): + raise NotImplementedError() + + def visit_if_(self, op: If): + raise NotImplementedError() + + def visit_op_(self, op: Op): + raise NotImplementedError() + + def 
visit_tuple_getitem_(self, op: TupleGetItem): + raise NotImplementedError() + + def visit_var_binding_(self, binding: VarBinding) -> None: + raise NotImplementedError() + + def visit_match_shape_(self, binding: MatchShape) -> None: + raise NotImplementedError() + + def visit_binding_block_(self, block: BindingBlock) -> None: + raise NotImplementedError() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + raise NotImplementedError() + + def visit_var_def_(self, var: Var) -> None: + raise NotImplementedError() + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + raise NotImplementedError() + + def visit_binding(self, binding: Binding) -> None: + if isinstance(binding, MatchShape): + self.visit_match_shape_(binding) + elif isinstance(binding, VarBinding): + self.visit_var_binding_(binding) + else: + raise TypeError("Invalid type: {0}".format(type(binding))) + + def visit_binding_block(self, block: BindingBlock) -> None: + if isinstance(block, DataflowBlock): + self.visit_dataflow_block_(block) + elif isinstance(block, BindingBlock): + self.visit_binding_block_(block) + else: + raise TypeError("Invalid type: {0}".format(type(block))) + + def visit_var_def(self, var: Var): + if isinstance(var, DataflowVar): + self.visit_dataflow_var_def_(var) + elif isinstance(var, Var): + self.visit_var_def_(var) + else: + raise TypeError("Invalid type: {0}".format(type(var))) + + +@tvm._ffi.register_object("expr_functor.PyExprVisitor") +class _PyExprVisitor(Object): + """ + A TVM object to support customization of ExprVisitor on the python side. + This is the decorated result returned from visitor decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: visitor, PyExprVisitor + """ + + def __init__( + self, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_runtime_dep_shape_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_shape_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_type: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprVisitor, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_runtime_dep_shape_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_shape_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_type, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. 
+
+        Parameters
+        ----------
+        expr : Expr
+            The expr to be visited.
+        """
+        return _ffi_api.PyExprVisitorVisitExpr(self, expr)
+
+    def visit_binding(self, binding: Binding) -> None:
+        """Generic dispatcher for Binding.
+
+        Parameters
+        ----------
+        binding : Binding
+            The binding to be visited.
+        """
+        return _ffi_api.PyExprVisitorVisitBinding(self, binding)
+
+    def visit_binding_block(self, block: BindingBlock) -> None:
+        """Generic dispatcher for BindingBlock.
+
+        Parameters
+        ----------
+        block : BindingBlock
+            The block to be visited.
+        """
+        return _ffi_api.PyExprVisitorVisitBindingBlock(self, block)
+
+    def visit_var_def(self, var: Var) -> None:
+        """Generic dispatcher for visiting the var definition site.
+        Note that visit_var_() will only visit the usage site of an Var.
+
+        Parameters
+        ----------
+        var : Var
+            The var to be visited.
+        """
+        return _ffi_api.PyExprVisitorVisitVarDef(self, var)
+
+
+class PyExprVisitor:
+    """
+    An abstract ExprVisitor with customized methods on the python-side.
+    This is the user facing class for method overwriting inheritance.
+    _tvm_metadata describes the class to inherit("cls") and the methods
+    that users can overwrite("methods").
+
+    Note: @relax.expr_functor.visitor is required for proper usage of any inherited class.
+
+    See also: visitor, _PyExprVisitor
+
+    Example:
+        @relax.expr_functor.visitor
+        class MyExprVisitor(PyExprVisitor):
+            ...
+    """
+
+    _tvm_metadata = {
+        "cls": _PyExprVisitor,
+        "methods": [
+            "visit_expr",
+            "visit_constant_",
+            "visit_tuple_",
+            "visit_var_",
+            "visit_dataflow_var_",
+            "visit_shape_expr_",
+            "visit_runtime_dep_shape_",
+            "visit_extern_func_",
+            "visit_global_var_",
+            "visit_function_",
+            "visit_call_",
+            "visit_seq_expr_",
+            "visit_if_",
+            "visit_op_",
+            "visit_tuple_getitem_",
+            "visit_binding",
+            "visit_var_binding_",
+            "visit_match_shape_",
+            "visit_binding_block",
+            "visit_binding_block_",
+            "visit_dataflow_block_",
+            "visit_var_def",
+            "visit_var_def_",
+            "visit_dataflow_var_def_",
+            "visit_type",
+            "visit_span",
+        ],
+    }
+
+    def visit_expr(self, expr: Expr) -> None:
+        """Generic dispatcher for Expr.
+        Users can customize this function to overwrite VisitExpr(const Expr& expr) on the C++ side.
+
+        Parameters
+        ----------
+        expr : Expr
+            The expr to be visited.
+        """
+        # Using self._outer() to ref _PyExprVisitor
+        return _ffi_api.PyExprVisitorVisitExpr(self._outer(), expr)
+
+    def visit_binding(self, binding: Binding) -> None:
+        """Generic dispatcher for Binding.
+        Users can customize this function to overwrite VisitBinding(const Binding& binding)
+        on the C++ side.
+
+        Parameters
+        ----------
+        binding : Binding
+            The binding to be visited.
+        """
+        # Using self._outer() to ref _PyExprVisitor
+        return _ffi_api.PyExprVisitorVisitBinding(self._outer(), binding)
+
+    def visit_binding_block(self, block: BindingBlock) -> None:
+        """Generic dispatcher for BindingBlock.
+        Users can customize this function to overwrite VisitBindingBlock(const BindingBlock& block)
+        on the C++ side.
+
+        Parameters
+        ----------
+        block : BindingBlock
+            The block to be visited.
+        """
+        # Using self._outer() to ref _PyExprVisitor
+        return _ffi_api.PyExprVisitorVisitBindingBlock(self._outer(), block)
+
+    def visit_var_def(self, var: Var) -> None:
+        """Generic dispatcher for visiting the var definition site.
+        Users can customize this function to overwrite VisitVarDef(const Var& var) on the C++ side.
+        Note that visit_var_() will only visit the usage site of an Var.
+
+        Parameters
+        ----------
+        var : Var
+            The var to be visited.
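+
+        Example
+        -------
+        A rough sketch of a subclass that prints the names of defined vars
+        (``VarDefPrinter`` is a hypothetical example class):
+
+        .. code-block:: python
+
+            @relax.expr_functor.visitor
+            class VarDefPrinter(PyExprVisitor):
+                def visit_var_def_(self, var: Var) -> None:
+                    print(var.name_hint)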
+ """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitVarDef(self._outer(), var) + + def visit_constant_(self, op: Constant) -> None: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_tuple_(self, op: Tuple) -> None: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_var_(self, op: Var) -> None: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: + """Visit RuntimeDepShape. + Users can customized this function to overwrite VisitExpr_(const RuntimeDepShapeNode* op) + on the C++ side. + + Parameters + ---------- + op : RuntimeDepShape + The RuntimeDepShape to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_extern_func_(self, op: ExternFunc) -> None: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_global_var_(self, op: GlobalVar) -> None: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_function_(self, op: Function) -> None: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_call_(self, op: Call) -> None: + """Visit Call. 
+ Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_seq_expr_(self, op: SeqExpr) -> None: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_if_(self, op: If) -> None: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_op_(self, op: Op) -> None: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) + + def visit_match_shape_(self, binding: MatchShape) -> None: + """Visit MatchShape. + Users can customized this function to overwrite VisitBinding_(const MatchShapeNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchShape + The MatchShape to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) + + def visit_binding_block_(self, block: BindingBlock) -> None: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) + + def visit_var_def_(self, var: Var) -> None: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. 
+ + Parameters + ---------- + var : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) + + def visit_type(self, t: Type) -> None: + """Visit Type. + Users can customized this function to overwrite VisitType(const Type& t) on the C++ side. + + Parameters + ---------- + t : Type + The Type to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitType(self._outer(), t) + + def visit_span(self, span: Span) -> None: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) + + +@tvm._ffi.register_object("expr_functor.PyExprMutator") +class _PyExprMutator(Object): + """ + A TVM object to support customization of ExprMutator on the python side. + This is the decorated result returned from mutator decorator. + + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: mutator, PyExprmutator + """ + + def __init__( + self, + builder: BlockBuilder = None, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_runtime_dep_shape_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_shape_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_type: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprMutator, + builder, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_runtime_dep_shape_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_shape_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_type, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. 
+
+        Returns
+        -------
+        result : Expr
+            The Expr after transformation.
+        """
+        return _ffi_api.PyExprMutatorVisitExpr(self, expr)
+
+    def visit_binding(self, binding: Binding) -> None:
+        """Generic dispatcher for Binding.
+
+        Parameters
+        ----------
+        binding : Binding
+            The binding to be visited.
+        """
+        return _ffi_api.PyExprMutatorVisitBinding(self, binding)
+
+    def visit_binding_block(self, block: BindingBlock) -> BindingBlock:
+        """Generic dispatcher for BindingBlock.
+
+        Parameters
+        ----------
+        block : BindingBlock
+            The block to be visited.
+
+        Returns
+        -------
+        result : BindingBlock
+            The binding block after transformation.
+        """
+        return _ffi_api.PyExprMutatorVisitBindingBlock(self, block)
+
+    def visit_var_def(self, var: Var) -> Var:
+        """Generic dispatcher for visiting the var definition site.
+        Note that visit_var_() will only visit the usage site of an Var.
+
+        Parameters
+        ----------
+        var : Var
+            The var to be visited.
+
+        Returns
+        -------
+        result : Var
+            The var after post-order rewritten.
+        """
+        return _ffi_api.PyExprMutatorVisitVarDef(self, var)
+
+
+class PyExprMutator:
+    """
+    An abstract ExprMutator with customized methods on the python-side.
+    This is the user facing class for method overwriting inheritance.
+    _tvm_metadata describes the class to inherit("cls"), the methods that users can
+    overwrite("methods"), and the constructor's parameters("fields").
+
+    Note: @relax.expr_functor.mutator is required for proper usage of any inherited class.
+
+    See also: mutator, _PyExprMutator
+
+    Example:
+        @relax.expr_functor.mutator
+        class MyExprMutator(PyExprMutator):
+            ...
+    """
+
+    _tvm_metadata = {
+        "cls": _PyExprMutator,
+        "fields": ["builder_"],
+        "methods": [
+            "visit_expr",
+            "visit_constant_",
+            "visit_tuple_",
+            "visit_var_",
+            "visit_dataflow_var_",
+            "visit_shape_expr_",
+            "visit_runtime_dep_shape_",
+            "visit_extern_func_",
+            "visit_global_var_",
+            "visit_function_",
+            "visit_call_",
+            "visit_seq_expr_",
+            "visit_if_",
+            "visit_op_",
+            "visit_tuple_getitem_",
+            "visit_binding",
+            "visit_var_binding_",
+            "visit_match_shape_",
+            "visit_binding_block",
+            "visit_binding_block_",
+            "visit_dataflow_block_",
+            "visit_var_def",
+            "visit_var_def_",
+            "visit_dataflow_var_def_",
+            "visit_type",
+            "visit_span",
+        ],
+    }
+
+    def __init__(self, mod: Optional[IRModule] = None) -> None:
+        """Constructor"""
+        self.builder_ = BlockBuilder(mod)
+
+    def visit_expr(self, expr: Expr) -> Expr:
+        """Generic dispatcher for Expr.
+        Users can customize this function to overwrite VisitExpr(const Expr& expr) on the C++ side.
+
+        Parameters
+        ----------
+        expr : Expr
+            The expr to be visited.
+
+        Returns
+        -------
+        result : Expr
+            The Expr after transformation.
+        """
+        # Using self._outer() to ref _PyExprMutator
+        return _ffi_api.PyExprMutatorVisitExpr(self._outer(), expr)
+
+    def visit_binding(self, binding: Binding) -> None:
+        """Generic dispatcher for Binding.
+        Users can customize this function to overwrite VisitBinding(const Binding& binding)
+        on the C++ side.
+
+        Parameters
+        ----------
+        binding : Binding
+            The binding to be visited.
+        """
+        # Using self._outer() to ref _PyExprMutator
+        return _ffi_api.PyExprMutatorVisitBinding(self._outer(), binding)
+
+    def visit_binding_block(self, block: BindingBlock) -> BindingBlock:
+        """Generic dispatcher for BindingBlock.
+        Users can customize this function to overwrite VisitBindingBlock(const BindingBlock& block)
+        on the C++ side.
+
+        Parameters
+        ----------
+        block : BindingBlock
+            The block to be visited.
+ + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBindingBlock(self._outer(), block) + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result: Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitVarDef(self._outer(), var) + + def visit_constant_(self, op: Constant) -> Expr: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_tuple_(self, op: Tuple) -> Expr: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_var_(self, op: Var) -> Expr: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: + """Visit RuntimeDepShape. + Users can customized this function to overwrite VisitExpr_(const RuntimeDepShapeNode* op) + on the C++ side. + + Parameters + ---------- + op : RuntimeDepShape + The RuntimeDepShape to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. 
+ + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_global_var_(self, op: GlobalVar) -> Expr: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_function_(self, op: Function) -> Expr: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_call_(self, op: Call) -> Expr: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_if_(self, op: If) -> Expr: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_op_(self, op: Op) -> Expr: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. 
+ """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) + + def visit_match_shape_(self, binding: MatchShape) -> None: + """Visit MatchShape. + Users can customized this function to overwrite VisitBinding_(const MatchShapeNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchShape + The MatchShape to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) + + def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) + + def visit_var_def_(self, var: Var) -> Var: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) + + def visit_dataflow_var_def_(self, var: DataflowVar) -> Var: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) + + def visit_type(self, t: Type) -> Type: + """Visit Type. + Users can customized this function to overwrite VisitType(const Type& t) on the C++ side. + + Parameters + ---------- + t : Type + The Type to be visited. + + Returns + ------- + result : Type + The type after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitType(self._outer(), t) + + def visit_span(self, span: Span) -> Span: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + + Returns + ------- + result : Span + The span after transformation. + """ + raise NotImplementedError + + def visit_expr_post_order(self, expr: Expr) -> Expr: + """Post-order rewrite an Expr and normalize. + + Parameters + ---------- + expr : Expr + The Expr to be rewritten. + + Returns + ------- + result : Expr + The Expr after post-order rewritten. 
+ """ + return _ffi_api.PyExprMutatorVisitExprPostOrder(self._outer(), expr) + + def set_var_remap(self, vid: Id, var: Var) -> None: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var. + var : Var + The new var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorSetVarRemap(self._outer(), vid, var) + + def get_var_remap(self, vid: Id) -> Var: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var + + Returns + ------- + var : Var + The remapped var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorGetVarRemap(self._outer(), vid) + + def visit_with_new_scope(self, expr: Expr) -> Expr: + """Rewrite the expr with a new scope, used in a Function's body and the branches of If. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + var : Var + The expr after visiting. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitWithNewScope(self._outer(), expr) + + def lookup_binding(self, var: Var) -> Optional[Expr]: + """Look up the value bound to a variable. + Note: For function parameters, this function returns NullOpt. + + Parameters + ---------- + var : Var + The var to be looked up. + + Returns + ------- + var : Var + The value bound to the input var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) + + def with_shape_and_type(self, var: Var, shape: Optional[Object], t: Type) -> Var: + """Create a new var with specified shape and type if the original var's shape or type does + not match with the specified ones. + + Parameters + ---------- + var : Var + The var to be updated. + shape : Optional[Object] + The specified shape. + t : Type + The specified type. + + Returns + ------- + var : Var + The var filled with shape and type. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorWithShapeAndType(self._outer(), var, shape, t) diff --git a/python/tvm/relax/ir/instrument.py b/python/tvm/relax/ir/instrument.py new file mode 100644 index 0000000000..fc51a796a7 --- /dev/null +++ b/python/tvm/relax/ir/instrument.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common relax pass instrumentation across IR variants.""" +import tvm +from tvm import relax + + +@tvm.instrument.pass_instrument +class WellFormedInstrument: + """An instrument that checks the input/output IRModule of the Pass + is well formed. It will skip specific passes, like Normalize. 
+ """ + + def __init__(self): + self.skip_pass_name = ["Normalize", "ResolveGlobals"] + + def run_before_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) + + def run_after_pass(self, mod, pass_info): + if pass_info.name not in self.skip_pass_name: + assert relax.analysis.well_formed(mod) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py new file mode 100644 index 0000000000..7428ea590b --- /dev/null +++ b/python/tvm/relax/op/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax core operators.""" + +# Operators +from .base import * +from .tensor import * +from .op_attrs import * +from . import builtin diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py new file mode 100644 index 0000000000..8dc6a1b4fb --- /dev/null +++ b/python/tvm/relax/op/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op""" +import tvm._ffi + +tvm._ffi._init_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py new file mode 100644 index 0000000000..c367d630de --- /dev/null +++ b/python/tvm/relax/op/base.py @@ -0,0 +1,311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# pylint: disable=redefined-builtin
+"""The base Relax operators."""
+from typing import Union, List, Optional
+
+import tvm
+from tvm.runtime.object import Object
+
+from . import _ffi_api
+from ..expr import Expr, ShapeExpr, Tuple, Call, ExternFunc
+from ..ty import DynTensorType, TupleType
+from ...ir import Array
+
+py_print = print  # pylint: disable=invalid-name
+
+
+def call_tir(
+    func: Union[str, Expr],
+    args: Union[Expr, Tuple, List[Expr]],
+    shape: Union[Tuple, ShapeExpr, List[int]],
+    dtype: Union[str, List[str]],
+    tir_vars: Optional[ShapeExpr] = None,
+) -> Call:
+    """
+    Call a destination-passing-style function and return the output.
+
+    Parameters
+    ----------
+    func : Union[str, Expr]
+        The destination-passing-style function, can be ExternFunc or PrimFunc.
+
+    args : Union[Expr, Tuple, List[Expr]]
+        The input arguments.
+
+    shape: Union[Tuple, ShapeExpr, List[int]]
+        The output shape. Tuple(ShapeExpr) if multiple outputs, ShapeExpr if single output.
+
+    dtype: Union[str, List[str]]
+        The output dtype. List[str] if multiple outputs, str if single output.
+
+    tir_vars : ShapeExpr, optional
+        ShapeExpr representing a tuple of integers to unpack when calling func.
+        Null if not used.
+
+    Returns
+    -------
+    ret: Call
+        A call node for the call_tir operator.
+    """
+    if isinstance(func, str):
+        func = ExternFunc(func)
+
+    if isinstance(shape, (list, tuple, Array)):
+        shape = ShapeExpr(shape)
+
+    if isinstance(args, Expr):
+        args = Tuple((args,))
+
+    if isinstance(args, (list, tuple)):
+        args = Tuple(args)
+
+    if isinstance(dtype, str):
+        output_type = DynTensorType(len(shape), dtype)
+    elif isinstance(dtype, (list, tuple)):
+        if len(shape) != len(dtype):
+            raise ValueError("The numbers of output shapes and output dtypes of call_tir do not match")
+        output_type = TupleType([DynTensorType(len(x), y) for x, y in zip(shape, dtype)])
+    else:
+        raise TypeError("Unsupported dtype for call_tir: " + str(type(dtype)))
+
+    return _ffi_api.call_tir(func, args, shape, output_type, tir_vars)
+
+
+def make_closure(
+    func: Expr,
+    args: Union[Tuple, List[Expr]],
+) -> Object:
+    """
+    Create a closure with free variables and return the closure.
+
+    Parameters
+    ----------
+    func : Expr
+        The function to be wrapped in the closure; can be ExternFunc or PrimFunc.
+
+    args : Union[Tuple, List[Expr]]
+        The input arguments.
+
+
+    Returns
+    -------
+    ret: Object
+        The VMClosure.
+    """
+
+    if isinstance(args, (list, tuple)):
+        args = Tuple(args)
+
+    return _ffi_api.make_closure(func, args)
+
+
+def invoke_closure(
+    closure: Expr,
+    args: Union[Tuple, List[Expr]],
+) -> Object:
+    """
+    Invoke a closure.
+
+    Parameters
+    ----------
+    closure : Expr
+        The VMClosure object.
+
+    args : Union[Tuple, List[Expr]]
+        The input arguments.
+
+
+    Returns
+    -------
+    ret: Object
+        The result.
+    """
+
+    if isinstance(args, (list, tuple)):
+        args = Tuple(args)
+
+    return _ffi_api.invoke_closure(closure, args)
+
+
+def render_object(val: tvm.Object) -> str:
+    """
+    Given a TVM Object, renders it in string form. Used for Relax printing and assertions.
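+    Tuples (ADTs with tag 0) are rendered as ``(field1, field2, ...)``;
+    other ADTs are rendered with their tag and fields.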
+
+    Parameters
+    ----------
+    val: tvm.Object
+        An object to render.
+
+    Returns
+    -------
+    ret: str
+        A string representing the value, ideally human-readable.
+    """
+    if isinstance(val, tvm.runtime.ndarray.NDArray):
+        return str(val)
+    # no pretty-printer by default, so if we don't handle this,
+    # then we can't look inside tuples
+    if isinstance(val, tvm.runtime.container.ADT):
+        # the fields array of an ADT cannot be directly accessed in Python
+        # so we have to get the length and index into the fields separately
+        fields = ", ".join([render_object(val[i]) for i in range(len(val))])
+        # special case: tag = 0 is a tuple
+        if val.tag == 0:
+            return f"({fields})"
+        return f"ADT(tag={val.tag}, fields=[{fields}])"
+    return str(val)
+
+
+@tvm.register_func("relax.run.print")
+def relax_print(format_str: str, *format_args: tvm.Object) -> None:
+    """
+    Takes a list of values to print, formats with the given format string.
+    If the format string is empty, simply prints.
+
+    Call from TVM script like this:
+    `relax.print(value1, value2, ..., valueN, format=format_str)`
+    or
+    `relax.print(value1, value2, ..., valueN) # format_str defaults to ""`
+
+    Parameters
+    ----------
+    format_str: str
+        A Python-style format string for printing the values.
+
+    format_args: List[Object]
+        The values to print.
+    """
+    val_strs = map(render_object, format_args)
+    if format_str == "":
+        py_print(*val_strs)
+    else:
+        py_print(format_str.format(*val_strs))
+
+
+def print(values: Union[Expr, List[Expr]], format: str) -> Expr:
+    """Create a call to the Relax print op, which prints the given values at runtime.
+
+    Parameters
+    ----------
+    values : Union[Expr, List[Expr]]
+        The values to print.
+
+    format: str
+        The format string.
+
+    Returns
+    -------
+    result : Expr
+        A relax Call, which will print the values during runtime.
+    """
+    if isinstance(values, Expr):
+        values = [values]
+    return _ffi_api.print(values, format)  # type: ignore # pylint: disable=no-member
+
+
+@tvm.register_func("relax.run.assert_op")
+def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None:
+    """
+    A variadic function. The first value serves as the assertion condition:
+    If the condition is true, then the operator does nothing.
+    If the condition is false, then the operator raises an assertion error.
+
+    The second argument is a format string for the error message (empty by default);
+    the remaining arguments serve as format arguments for that message.
+    If the format string is the empty string, then the error message will simply include
+    a comma-separated list of the format arguments.
+    The condition argument is not included in the format string.
+
+    Parameters
+    ----------
+    condition: tvm.Object
+        The assertion condition. Must be a boolean scalar.
+
+    format_str: str
+        A Python-style format string for rendering the error message.
+
+    format_args: List[tvm.Object]
+        Values used for formatting the string.
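+
+    Raises
+    ------
+    AssertionError
+        If the condition evaluates to false at runtime.
+    ValueError
+        If the condition is not a boolean scalar NDArray or the format
+        string is not a string.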
+ """ + if not isinstance(format_str, str): + raise ValueError( + f"The format string argument to assert must be a string, given {type(format_str)})" + ) + + # should be guaranteed by the type system + if not isinstance(condition, tvm.runtime.ndarray.NDArray): + raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") + + # may happen if the original program had unknown shape or dtype for the tensor's type + dtype = condition.dtype + if dtype != "bool": + raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") + shape = condition.shape + if len(shape) != 0: + raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") + + val = condition.numpy() + if not val: + error_message = "Assertion Failed" + if format_args or format_str != "": + rendered = map(render_object, format_args) + if format_str != "": + error_message = format_str.format(*rendered) + else: + error_message = ", ".join(rendered) + raise AssertionError(error_message) + + +def assert_op(condition: Expr, format_args: Optional[List[Expr]] = None, format: str = "") -> Expr: + """ + Create a call to Relax's assert_op operation (`assert` is reserved in Python, + so the name must be distinct). + + Parameters + ---------- + condition: Expr + The assertion condition. + + format_args: List[Expr] + Format arguments for the error message if the condition fails. + + format_str: str + The format string for the error message. + + Returns + ------- + result : Expr + A Call to the Relax assert operation. + """ + if format_args is None: + format_args = [] + return _ffi_api.assert_op(condition, format_args, format) # type: ignore + + +def shape_of(expr: Expr) -> Expr: + """Get shape of a tensor. + + Parameters + ---------- + expr : Expr + The input Expr. + + Returns + ------- + result : Expr + A relax Call, which gets the shape of the input + """ + return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/op/builtin/__init__.py b/python/tvm/relax/op/builtin/__init__.py new file mode 100644 index 0000000000..04837724b1 --- /dev/null +++ b/python/tvm/relax/op/builtin/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax builtin operators.""" + +from .builtin import * diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py new file mode 100644 index 0000000000..42fe8cb652 --- /dev/null +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+"""FFI APIs for tvm.relax.op.builtin"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.op.builtin", __name__)
diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
new file mode 100644
index 0000000000..0c80ba73d6
--- /dev/null
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+"""The builtin Relax operators."""
+
+from typing import List, Union
+from tvm.ir.expr import PrimExpr
+from . import _ffi_api
+from ...expr import ShapeExpr, Call
+
+
+def alloc_tensor(
+    shape: Union[ShapeExpr, PrimExpr, List[PrimExpr]], dtype: str, runtime_device_index: int
+) -> Call:
+    """Construct a Call to the relax.builtin.alloc_tensor op.
+
+    Parameters
+    ----------
+    shape : Union[ShapeExpr, PrimExpr, List[PrimExpr]]
+        The shape of the tensor to be allocated.
+    dtype : str
+        The dtype of the tensor to be allocated.
+    runtime_device_index : int
+        The device index indicating on which device the tensor is allocated at runtime.
+
+    Returns
+    -------
+    result : Call
+        A call to the relax.builtin.alloc_tensor op.
+    """
+    if not isinstance(shape, ShapeExpr):
+        if not isinstance(shape, (tuple, list)):
+            shape = (shape,)
+        shape = ShapeExpr(shape)
+    return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index)
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
new file mode 100644
index 0000000000..f8ef3107de
--- /dev/null
+++ b/python/tvm/relax/op/op_attrs.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The attributes node used for Relax operators""" +from tvm.ir import Attrs +import tvm._ffi + + +@tvm._ffi.register_object("relax.attrs.AllocTensorAttrs") +class AllocTensorAttrs(Attrs): + """Attributes used in alloc_tensor operators""" + + +@tvm._ffi.register_object("relax.attrs.VMAllocStorageAttrs") +class VMAllocStorageAttrs(Attrs): + """Attributes used in VM alloc_storage operators""" + + +@tvm._ffi.register_object("relax.attrs.VMAllocTensorAttrs") +class VMAllocTensorAttrs(Attrs): + """Attributes used in VM alloc_tensor operators""" + + +@tvm._ffi.register_object("relax.attrs.UniqueAttrs") +class UniqueAttrs(Attrs): + """Attributes used for the unique operator""" + + +@tvm._ffi.register_object("relax.attrs.PrintAttrs") +class PrintAttrs(Attrs): + """Attributes used for the print operator""" + + +@tvm._ffi.register_object("relax.attrs.AssertOpAttrs") +class AssertOpAttrs(Attrs): + """Attributes used for the assert operator""" diff --git a/python/tvm/relax/op/tensor.py b/python/tvm/relax/op/tensor.py new file mode 100644 index 0000000000..3d48973f01 --- /dev/null +++ b/python/tvm/relax/op/tensor.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# pylint: disable=redefined-builtin +"""Basic tensor operations.""" +import numpy as np +import tvm + +from . import _ffi_api +from ..expr import Expr + + +def add(lhs: Expr, rhs: Expr) -> Expr: + return _ffi_api.add(lhs, rhs) + + +def multiply(lhs: Expr, rhs: Expr) -> Expr: + return _ffi_api.multiply(lhs, rhs) + + +def unique( + data: Expr, + sorted: bool = True, + return_inverse: bool = False, + return_counts: bool = False, + dim: int = -1, +) -> Expr: + """Find the unique elements and the new index of each item in a given tensor. + + Parameters + ---------- + data : Expr + The input tensor. + + sorted: bool + Whether to sort the unique elements in ascending order before + returning as output. + + return_inverse: bool + Whether to return an additional tensor with indices for where elements in + the original input ended up in the returned unique list. + + return_counts: bool + Whether to return an additional tensor with counts of each unique elements. + + dim: int + The dimension to apply unique. If negative, the unique of the flattened input is returned. + + Returns + ------- + ret: Expr + The created relax call with + """ + + return _ffi_api.unique(data, sorted, return_inverse, return_counts, dim) + + +@tvm.register_func("relax.run.unique") +def numpy_unique( + a: tvm.nd.array, + sort: int, + return_inverse: int, + return_counts: int, + dim: int, +) -> tvm.nd.array: + """Returns the unique elements of the input tensor. + + Uses numpy.unique to compute unique elements. 
+ """ + # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True + if bool(return_inverse) or bool(return_counts): + raise NotImplementedError("missing support return_inverse or return_counts set to true") + if dim < 0: + dim = None + a_numpy = a.numpy() + # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci. + output_sorted_numpy, indices = np.unique(a_numpy, return_index=True) + if sort: + return tvm.nd.array(output_sorted_numpy) + output_numpy = [a_numpy.flatten()[index] for index in sorted(indices, reverse=True)] + return tvm.nd.array(output_numpy) diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py new file mode 100644 index 0000000000..a6e3a94251 --- /dev/null +++ b/python/tvm/relax/testing/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""The Relax testing namespace containing nn and translator.""" + +from .nn import * +from .relay_translator import * +from .ast_printer import dump_ast diff --git a/python/tvm/relax/testing/_ffi_api.py b/python/tvm/relax/testing/_ffi_api.py new file mode 100644 index 0000000000..5631489989 --- /dev/null +++ b/python/tvm/relax/testing/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI API for for Relax.""" +import tvm._ffi + +tvm._ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py new file mode 100644 index 0000000000..101821544c --- /dev/null +++ b/python/tvm/relax/testing/ast_printer.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, abstract-method, arguments-differ +""" +Utility script for printing Relax modules as AST diagrams, +only intended to show how the AST is put together. +It is not a pretty-printer and, in fact, is more of an ugly-printer, +but it can be useful for tutorials and debugging. +""" +from __future__ import annotations # must import to defer parsing of annotations +from typing import Dict, Iterable +import tvm +from tvm import relax +from tvm.ir.expr import PrimExpr +from tvm.relax import ExprFunctor + + +def wrap_quotes(text: str) -> str: + """ + Wraps the text in quotes. + """ + return f'"{text}"' + + +class ASTPrinter(ExprFunctor): + """ + Class for recursing down ASTs and printing them in a very simple format, + mainly for instructive purposes and, perhaps, debugging. + """ + + def __init__( + self, + indent_str=" ", + include_type_annotations=True, + include_shape_annotations=True, + include_call_attrs=True, + ): + self.indent_str = indent_str + self.include_type_annotations = include_type_annotations + self.include_shape_annotations = include_shape_annotations + self.include_call_attrs = include_call_attrs + + def visit_expr(self, expr: relax.Expr) -> str: + # extend so we also dispatch to bindings and binding blocks, + # a little silly but IRFunctor hasn't been ported to Python + if isinstance(expr, relax.DataflowBlock): + return self.visit_dataflow_block_(expr) + if isinstance(expr, relax.BindingBlock): + return self.visit_binding_block_(expr) + if isinstance(expr, relax.Binding): + return self.visit_binding_(expr) + return super().visit_expr(expr) + + def indent(self, text: str) -> str: + """ + Indent all lines of the input. + """ + if text == "": + return "" + lines = text.split("\n") + return self.indent_str + f"\n{self.indent_str}".join(lines) + + def build_ast_node(self, nodename: str, force_newline=False, **kwargs: Dict[str, str]): + """ + Returns 'nodename(..., fields[i][0]=fields[i][1], ...)' + with appropriate indentation + """ + return self.build_list( + map(lambda field: f"{field[0]}={field[1]}", kwargs.items()), + open_tok=f"{nodename}(", + close_tok=")", + force_newline=force_newline, + ) + + def build_list( + self, members: Iterable[str], open_tok="[", close_tok="]", force_newline=False + ) -> str: + """ + Builds a list of the members given, appropriately indented, + with each field on a line. + (special case: if there is only one field, then we do not put it on a new line + unless that field contains a newline or `force_newline` is set to true). + `open_tok` and `close_tok` are used to open and close the list, respectively. 
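+
+        For example, ``build_list(["a", "b"])`` with the default tokens and
+        the default indent produces::
+
+            [
+              a,
+              b
+            ]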
+ """ + mem_list = list(members) + if not mem_list: + return f"{open_tok}{close_tok}" + if len(mem_list) == 1 and not force_newline and "\n" not in mem_list[0]: + return f"{open_tok}{mem_list[0]}{close_tok}" + member_lines = ",\n".join(map(self.indent, mem_list)) + return f"{open_tok}\n{member_lines}\n{close_tok}" + + def visit_constant_(self, op: relax.Constant) -> str: + # simple rule of thumb: keep scalars inline, but anything larger goes on a new one + force_newline = len(op.data.shape) > 0 + return self.build_ast_node("Constant", force_newline=force_newline, data=str(op.data)) + + def visit_tuple_(self, op: relax.Tuple) -> str: + return self.build_ast_node("Tuple", fields=self.build_list(map(self.visit_expr, op.fields))) + + def visit_dataflow_var_(self, op: relax.DataflowVar) -> str: + fields = {"name_hint": wrap_quotes(op.name_hint)} + if op.shape_ and self.include_shape_annotations: + fields["shape_"] = self.visit_expr(op.shape_) + if op._checked_type_ and self.include_type_annotations: + fields["_checked_type_"] = self.visit_type_(op._checked_type_) + return self.build_ast_node("DataflowVar", **fields) + + def visit_var_(self, op: relax.Var) -> str: + fields = {"name_hint": wrap_quotes(op.name_hint)} + if op.shape_ and self.include_shape_annotations: + fields["shape_"] = self.visit_expr(op.shape_) + if op._checked_type_ and self.include_type_annotations: + fields["_checked_type_"] = self.visit_type_(op._checked_type_) + return self.build_ast_node("Var", **fields) + + def visit_shape_expr_(self, op: relax.ShapeExpr) -> str: + return self.build_ast_node( + "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_, op.values)) + ) + + def visit_runtime_dep_shape_(self, _: relax.RuntimeDepShape) -> str: + # no fields, apparently? + return self.build_ast_node("RuntimeDepShape") + + def visit_extern_func_(self, op: relax.ExternFunc) -> str: + return self.build_ast_node("ExternFunc", global_symbol=wrap_quotes(op.global_symbol)) + + def visit_global_var_(self, op: relax.GlobalVar) -> str: + return self.build_ast_node("GlobalVar", name_hint=wrap_quotes(op.name_hint)) + + def visit_function_(self, op: relax.Function) -> str: + fields = { + "params": self.build_list(map(self.visit_expr, op.params)), + "body": self.visit_expr(op.body), + "ret_shape": self.visit_expr(op.ret_shape), + } + if op.ret_type: + fields["ret_type"] = self.visit_type_(op.ret_type) + if op.attrs: + fields["attrs"] = self.build_list( + map( + lambda kv: f"{wrap_quotes(str(kv[0]))}: {wrap_quotes(str(kv[1]))}", + op.attrs.items(), + ), + open_tok="{", + close_tok="}", + ) + return self.build_ast_node("Function", **fields) + + def visit_call_(self, op: relax.Call) -> str: + fields = { + "op": self.visit_expr(op.op), + "args": self.build_list(map(self.visit_expr, op.args)), + } + if op.type_args: + fields["type_args"] = self.build_list(map(self.visit_type_, op.type_args)) + if op.attrs and self.include_call_attrs: + + def display_attrs(attr_key): + attr_val = op.attrs[attr_key] + # attrs can be strings but also other types; + # we want to wrap strings in quotes + # (__repr__ would work but it uses single quotes) + attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val) + return f"{wrap_quotes(attr_key)}: {attr_str}" + + fields["attrs"] = self.build_list( + map(display_attrs, op.attrs.keys()), + open_tok="{", + close_tok="}", + ) + return self.build_ast_node("Call", **fields) + + def visit_seq_expr_(self, op: relax.SeqExpr) -> str: + return self.build_ast_node( + "SeqExpr", + 
blocks=self.build_list(map(self.visit_binding_block_, op.blocks)), + body=self.visit_expr(op.body), + ) + + def visit_if_(self, op: relax.If) -> str: + return self.build_ast_node( + "If", + cond=self.visit_expr(op.cond), + true_branch=self.visit_expr(op.true_branch), + false_branch=self.visit_expr(op.false_branch), + ) + + def visit_op_(self, op: tvm.ir.Op) -> str: + # TODO: List other attributes? + return self.build_ast_node("Op", name=wrap_quotes(op.name)) + + def visit_prim_expr_(self, prim_expr: PrimExpr) -> str: + # TODO: We may want to print PrimExpr ASTs, but this is a simplification for now + return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`") + + def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str: + return self.build_ast_node( + "TupleGetItem", + tuple_value=self.visit_expr(op.tuple_value), + index=str(op.index), + ) + + def visit_type_(self, type_node: relax.Type) -> str: + """ + Recurse down types and print their ASTs too + """ + if isinstance(type_node, relax.ShapeType): + return self.build_ast_node("ShapeType") + if isinstance(type_node, relax.ObjectType): + return self.build_ast_node("ObjectType") + if isinstance(type_node, relax.DynTensorType): + fields = {} + if type_node.ndim is not None: + fields["ndim"] = str(type_node.ndim) + if type_node.dtype != "": + fields["dtype"] = type_node.dtype + return self.build_ast_node("DynTensorType", **fields) + if isinstance(type_node, relax.DimType): + return self.build_ast_node("DimType") + if isinstance(type_node, relax.TupleType): + return self.build_ast_node( + "TupleType", fields=self.build_list(map(self.visit_type_, type_node.fields)) + ) + if isinstance(type_node, relax.FuncType): + return self.build_ast_node( + "FuncType", + arg_types=self.build_list(map(self.visit_type_, type_node.arg_types)), + ret_type=self.visit_type_(type_node.ret_type), + # TODO: skipping type params and type constraints + ) + raise ValueError(f"Invalid Relax Type {type_node} ({type(type_node)})") + + def visit_binding_block_(self, block: relax.BindingBlock) -> str: + """ + Recurse down binding blocks + """ + return self.build_ast_node( + "BindingBlock", + bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True), + ) + + def visit_dataflow_block_(self, block: relax.DataflowBlock) -> str: + """ + Recurse down a dataflow block + """ + return self.build_ast_node( + "DataflowBlock", + bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True), + ) + + def visit_binding_(self, binding: relax.Binding) -> str: + """ + Distinguish between binding types + """ + if isinstance(binding, relax.MatchShape): + return self.visit_match_shape_(binding) + if isinstance(binding, relax.VarBinding): + return self.visit_var_binding_(binding) + raise ValueError(f"Invalid binding type in {binding}: {type(binding)}") + + def visit_match_shape_(self, match_shape: relax.MatchShape) -> str: + """ + Handle match shape + """ + return self.build_ast_node( + "MatchShape", + value=self.visit_expr(match_shape.value), + pattern=self.build_list(map(self.visit_prim_expr_, match_shape.pattern)), + var=self.visit_expr(match_shape.var), + ) + + def visit_var_binding_(self, var_binding: relax.VarBinding) -> str: + """ + Handle ordinary var bindings + """ + return self.build_ast_node( + "VarBinding", + var=self.visit_expr(var_binding.var), + value=self.visit_expr(var_binding.value), + ) + + +def dump_ast( + exp: relax.Expr, + indent_str=" ", + include_type_annotations=True, + include_shape_annotations=True, + 
include_call_attrs=True,
+) -> str:
+    """
+    Dump an AST in a text format.
+    Can vary the indentation string and choose whether to include
+    type and shape annotations or call attributes.
+    """
+    printer = ASTPrinter(
+        indent_str=indent_str,
+        include_type_annotations=include_type_annotations,
+        include_shape_annotations=include_shape_annotations,
+        include_call_attrs=include_call_attrs,
+    )
+    return printer.visit_expr(exp)
diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py
new file mode 100644
index 0000000000..e20a24b838
--- /dev/null
+++ b/python/tvm/relax/testing/nn.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""PyTorch-like nn.Module API for constructing workloads."""
+
+
+from typing import List, Any, Callable
+import numpy as np
+
+import tvm
+from tvm import relax, topi, tir
+
+
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
+    """Emit a TE computation into the current BlockBuilder and return the bound relax.Var."""
+    return relax.BlockBuilder.current().emit_te(func, *args, **kwargs)
+
+
+class Placeholder(relax.Var):
+    """A placeholder variable that can represent model input."""
+
+    def __init__(self, shape, dtype="float32", name="data"):
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError("the shape of Placeholder is expected to be a list or a tuple")
+        ndim = len(shape)
+        type_anno = relax.DynTensorType(ndim, dtype)
+        super().__init__(relax.BlockBuilder.current().get_unique_name(name), shape, type_anno)
+
+
+class Parameter(relax.Var):
+    """A special kind of relax Var that represents model parameters (weights)."""
+
+    def __init__(self, shape, dtype="float32", name="param"):
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError("the shape of Parameter is expected to be a list or a tuple")
+        ndim = len(shape)
+        type_anno = relax.DynTensorType(ndim, dtype)
+        super().__init__(relax.BlockBuilder.current().get_unique_name(name), shape, type_anno)
+
+
+class Module:
+    """Base class for all model modules.
+
+    A neural network or a layer can subclass this class.
+
+    Example
+    -------
+    .. code-block:: python
+
+        # Define a linear layer
+        class Linear(Module):
+            def __init__(self, in_features, out_features, bias=True):
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = Parameter((in_features, out_features), name="linear_weight")
+                if bias:
+                    self.bias = Parameter((out_features,), name="linear_bias")
+                else:
+                    self.bias = None
+
+            # All submodules should implement forward.
+            # Defines the forward computation performed at every call.
+ def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y + """ + + def parameters(self) -> List[Parameter]: + """Return the list of parameters in the module.""" + return _unpack_params(self.__dict__) + + def forward(self, input: relax.Expr): + """Define the computation performed at every call.""" + raise NotImplementedError() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _unpack_params(value: object) -> List[relax.Var]: + if isinstance(value, Parameter): + return [value] + if isinstance(value, Module): + return value.parameters() + if isinstance(value, dict): + params = [] + for v in value.values(): + params += _unpack_params(v) + return params + if isinstance(value, (list, tuple)): + params = [] + for v in value: + params += _unpack_params(v) + return params + if isinstance(value, (int, float, str)): + return [] + raise TypeError("not supported type when unpacking parameters: {}".format(type(value))) + + +def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: + """Utility function to initialize model's parameters.""" + shape_dict = {v.name_hint: v.shape_ for v in mod["main"].params} + params = [] + for k, v in shape_dict.items(): + if k.startswith("data"): + continue + if isinstance(v, relax.ShapeExpr): + shape = [] + for i in v: + if isinstance(i, tir.IntImm): + shape.append(int(i)) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + params.append(tvm.nd.array(np.zeros(shape).astype(np.float32))) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + return params + + +class Sequential(Module): + """A sequential container that concatenates modules in it. + + Example + ------- + .. code-block:: python + + model = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(), + nn.Conv2d(20, 64, 5), + nn.ReLU() + ) + """ + + def __init__(self, *modules: Module): + self.modules = modules + + def forward(self, input: relax.Expr) -> relax.Var: + for module in self.modules: + input = module(input) + return input + + +class ReLU(Module): + """Applies the rectified linear unit activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.relu, input) + + +class LogSoftmax(Module): + """Applies log softmax activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.log_softmax, input) + + +class Linear(Module): + """Applies a linear transformation to the input data: :math:`y = xA + b`.""" + + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter((in_features, out_features), name="linear_weight") + if bias: + self.bias = Parameter((out_features,), name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py new file mode 100644 index 0000000000..7c49631709 --- /dev/null +++ b/python/tvm/relax/testing/relay_translator.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, invalid-name, no-else-return +"""Relay to Relax translator.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import tvm +from tvm import relax, relay +from tvm.ir.module import IRModule +from tvm.relax.testing import nn +from tvm.relay.backend.te_compiler import select_implementation +from tvm.runtime import NDArray +from tvm.target import Target +from tvm.meta_schedule.relay_integration import _autotvm_silencer + + +def from_relay( + func: relay.Function, + target: Target, + relay_params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Optional[Dict[str, Any]] = None, + disabled_pass: Optional[List[str]] = None, + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None, +) -> IRModule: + """Convert a Relay function into a Relax program. + + Parameters + ---------- + func : relay.Function + Relay function to be converted. + + target: Target + The target to compile the model, used for selecting topi functions. + + relay_params: Optional[Dict[str, NDArray]] + Parameters to bind. + + opt_level: int + The optimization level. + + pass_config: Optional[Dict[str, Any]] + Pass configuration. + + disabled_pass: Optional[List[str]] + Passes to disable. + + translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] + Dict that maps op names to user-defined PrimFuncs. + Takes relay operator names and forces them to user-defined PrimFuncs during translation. 
+ + Returns + ------- + mod : tvm.IRModule + The Relax IRModule for compilation + """ + # A map to store the mapping of Relay Expr to its corresponding Relax var + var_map = {} + # The output of the function + output_var = None + + if not isinstance(target, Target): + target = Target(target) + if disabled_pass is None: + disabled_pass = [] + if pass_config is None: + pass_config = { + "relay.FuseOps.max_depth": 1, # Disable relay fusion + "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": True, + } + + if relay_params: + func = relay.build_module.bind_params_by_name(func, relay_params) + + params = [] + + def convert_shape(shape: List[tvm.tir.PrimExpr]) -> List[tvm.tir.PrimExpr]: + """Convert the relay shape to relax shape by changing Any dim to symbolic dim""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + ret.append(tvm.tir.IntImm("int64", int(dim))) + elif isinstance(dim, tvm.tir.Any): + ret.append(tvm.tir.Var("d", "int64")) + else: + ret.append(dim) + return ret + + def visit_func(node): + nonlocal output_var + if isinstance(node, relay.Var): + if isinstance(node.type_annotation, relay.TensorType): + var_map[node] = nn.Placeholder( + tuple(convert_shape(node.type_annotation.shape)), + node.type_annotation.dtype, + node.name_hint, + ) + params.append(var_map[node]) + else: + raise TypeError("The type of relay.Var to be translated must be of TensorType.") + elif isinstance(node, relay.Call): + args = node.args + new_args = [] + te_inputs = [] + for arg in args: + if arg in var_map: + new_args.append(var_map[arg]) + te_inputs.append(tvm.relax.expr.te_tensor(new_args[-1])) + + op_name = node.op.name + attrs = node.attrs + out_type = node.checked_type + + if translate_op_with_tir and op_name in translate_op_with_tir: + tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name) + call = relax.call_tir(tir_gvar, new_args, out_type.shape, out_type.dtype) + var = bb.emit(call) + else: + with target: + best_impl, outputs = select_implementation( + node.op, + attrs, + te_inputs, + out_type, + target, + use_autotvm=False, + ) + compute_func = best_impl.compute + name_hint = op_name.split(".")[-1] + var = bb.emit_te( + compute_func, + attrs, + new_args, + node.checked_type, + primfunc_name_hint=name_hint, + ) + + output_var = var + var_map[node] = var + elif isinstance(node, relay.Constant): + # fill the shape and checked_type fields of the Constant + new_constant = relay.Constant(node.data) + var_map[node] = new_constant + elif isinstance(node, relay.Tuple): + new_fields = [] + for field in node.fields: + if field in var_map: + new_fields.append(var_map[field]) + else: + raise RuntimeError("field is not in var_map.") + new_tuple = relax.Tuple(new_fields) + new_tuple_var = relax.BlockBuilder.current().emit(new_tuple) + var_map[node] = new_tuple_var + output_var = new_tuple_var + elif isinstance(node, relay.TupleGetItem): + if node.tuple_value in var_map: + new_tuple = var_map[node.tuple_value] + new_tuple_get_item_node = relax.TupleGetItem(new_tuple, node.index) + new_tuple_get_item_var = relax.BlockBuilder.current().emit(new_tuple_get_item_node) + var_map[node] = new_tuple_get_item_var + output_var = new_tuple_get_item_var + else: + raise RuntimeError("tuple is not in var_map") + elif isinstance(node, relay.Function): + cur_bb = relax.BlockBuilder.current() + gv = cur_bb.emit_output(output_var) + df_block = cur_bb._end_block() + cur_bb._blocks.append(df_block) + cur_bb.emit_func_output(gv, params) + elif isinstance(node, tvm.ir.Op): 
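+                # tvm.ir.Op nodes need no translation of their own; they are
+                # handled when the relay.Call that uses them is visited.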
+                pass
+            else:
+                raise TypeError("{} is not supported yet.".format(str(type(node))))
+
+    # A subset of the relay->relay optimization passes.
+    # See src/relay/backend/utils.cc::GetPassPrefix() for the full list.
+    seq = tvm.get_global_func("relay.backend.GetPassPrefixSeq")(True, True)
+
+    # Since optimization passes and OpStrategy are highly context-dependent,
+    # we match the exact same context with `extract_task_from_relay()` env
+    with _autotvm_silencer(), tvm.transform.PassContext(
+        opt_level=opt_level,
+        config=pass_config,
+        disabled_pass=disabled_pass,
+    ):
+        mod = tvm.IRModule.from_expr(func)
+        mod = seq(mod)
+        bb = relax.BlockBuilder()
+        with bb.function("main"):
+            bb._begin_dataflow_block()
+            relay.analysis.post_order_visit(mod["main"], visit_func)
+
+    return bb.get()
diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py
new file mode 100644
index 0000000000..c26b15c860
--- /dev/null
+++ b/python/tvm/relax/testing/transform.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ
+"""Relax transformation passes for testing"""
+
+from __future__ import annotations
+from tvm import ir
+from tvm import relax
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.target import Target
+from tvm.ir import transform
+from tvm.relax import PyExprMutator
+from tvm.relax.expr import Call
+from tvm.relay.backend.te_compiler import select_implementation
+
+
+@ir.transform.module_pass(opt_level=0)
+class LowerWithRelayOpStrategyPass(transform.Pass):
+    """Lower Relax ops into TIR by using Relay OpStrategy.
+
+    Since operators like conv2d, add, and matmul are relay-/relax-independent,
+    this pass assumes a Relay equivalent can always be found for such Relax ops,
+    and uses the (legacy) Relay OpStrategy to perform lowering and find the
+    TOPI implementation.
+
+    Parameters
+    ----------
+    target : Target
+        The compilation target info.
+
+    Returns
+    -------
+    pass : transform.Pass
+        The lowering pass.
+    """
+
+    def __init__(self, target: Target):
+        self.target = target
+
+    def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
+        """Implement lowering mechanism.
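+        For each operator call that has a Relay equivalent, the pass emits a
+        call_tir to a TIR PrimFunc generated from the TOPI compute chosen by
+        the Relay OpStrategy.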
+
+        Parameters
+        ----------
+        mod : IRModule
+            Input IRModule with Relax ops
+
+        ctx: PassContext
+            Pass context
+
+        Returns
+        -------
+        out_mod : IRModule
+            Output IRModule with lowered TIR functions
+        """
+        target = self.target
+
+        @relax.expr_functor.mutator
+        class Lowerer(PyExprMutator):
+            """Mutator that performs lowering."""
+
+            def visit_call_(self, call_node: Call):
+                # Ignore function calls
+                # We only target calls for operators
+                if isinstance(call_node.op, (relax.GlobalVar, relax.expr.ExternFunc)):
+                    return call_node
+
+                # Current relax op name simply adds "relax." prefix to relay op name.
+                # Thus, remove "relax." prefix to deduce relay op name.
+                relay_op_name = call_node.op.name[6:]
+                # Check if equivalent relay op exists. If not, return the original call.
+                if relay_op_name in ir.Op.list_op_names():
+                    relay_op = ir.Op.get(relay_op_name)
+
+                    te_inputs = [relax.expr.te_tensor(arg) for arg in call_node.args]
+                    best_impl_tuple = select_implementation(
+                        relay_op,
+                        call_node.attrs,
+                        te_inputs,
+                        call_node.checked_type,
+                        target,
+                        use_autotvm=False,
+                    )
+                    compute_func = best_impl_tuple[0].compute
+                    # Extract the name of the operator without the prefix
+                    # e.g., for relay op "nn.conv2d", name_hint would be conv2d
+                    name_hint = relay_op_name.split(".")[-1]
+
+                    return self.builder_.call_te(
+                        compute_func,
+                        call_node.attrs,
+                        call_node.args,
+                        call_node.attrs,
+                        primfunc_name_hint=name_hint,
+                    )
+                else:
+                    return call_node
+
+            # TODO(@team): The transform() wrapper is necessary to include TIR functions.
+            # IMO, this is a bit unintuitive. Can we improve this?
+            def transform(self):
+                for gv, func in mod.functions.items():
+                    if isinstance(func, relax.Function):
+                        updated_func = self.visit_expr(func)
+                        self.builder_.update_func(gv, updated_func)
+                new_mod = self.builder_.get()
+                new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod
+                return new_mod
+
+        return Lowerer().transform()
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
new file mode 100644
index 0000000000..894bfea405
--- /dev/null
+++ b/python/tvm/relax/transform/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin
+"""Relax IR transformations."""
+
+from .transform import *
+from .fma_rewrite import *
diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py
new file mode 100644
index 0000000000..667aa62c2c
--- /dev/null
+++ b/python/tvm/relax/transform/_ffi_api.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+"""FFI APIs for tvm.relax.transform"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.transform", __name__)
diff --git a/python/tvm/relax/transform/fma_rewrite.py b/python/tvm/relax/transform/fma_rewrite.py
new file mode 100644
index 0000000000..ab4aeb8024
--- /dev/null
+++ b/python/tvm/relax/transform/fma_rewrite.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, invalid-name, abstract-method
+"""Perform fused multiply-add rewriting in Python"""
+from tvm.ir import Op
+from tvm.ir.module import IRModule
+from tvm.ir.transform import module_pass
+from ..expr_functor import mutator, PyExprMutator
+from ..expr import Call, Function, Var, RuntimeDepShape
+from ..transform import dataflowblock_pass
+
+
+@mutator
+class EwiseFMARewriter(PyExprMutator):
+    """Rewrites the relax.add call to a relax.ewise_fma call
+    when detecting the multiply-add pattern.
+
+    Example
+    --------
+    x0 = mul(a, b)
+    z0 = add(x0, c)
+    -->
+    z0 = ewise_fma(a, b, c)
+    """
+
+    def visit_call_(self, call: Call) -> Call:  # pylint: disable=arguments-differ
+        call = self.visit_expr_post_order(call)
+        add_op = Op.get("relax.add")
+        multiply_op = Op.get("relax.multiply")
+        ewise_fma_op = Op.get("relax.ewise_fma")
+
+        if call.op == add_op:
+            value = self.lookup_binding(call.args[0])
+            if isinstance(value, Call) and value.op == multiply_op:
+                fma_call = Call(
+                    ewise_fma_op, [value.args[0], value.args[1], call.args[1]], None, None
+                )
+                return fma_call
+
+        return call
+
+
+@dataflowblock_pass(opt_level=2, name="ewise_fma_rewriter")
+class EwiseRewriteFMA:
+    """The wrapper for the EwiseFMARewriter pass."""
+
+    def transform_dataflowblock(self, block, mod, ctx):
+        return EwiseFMARewriter().visit_binding_block(block)
+
+
+@mutator
+class EwiseFuseFMAMutator(PyExprMutator):
+    """Performs multiply-add fusion. The difference between EwiseFMARewriter and this
+    EwiseFuseFMAMutator is that this mutator generates a sub-function (subgraph)
+    whose body is a CallNode calling the relax.ewise_fma op, and rewrites the
+    relax.add call in the main function into a call to that sub-function.
+ + Example + -------- + Before-transformation IRModule: + def main(): + x0 = mul(a, b) + z0 = add(x0, c) + --> + After-transformation IRModule: + def ewise_fused(x, y, z): + return relax.ewise_fma(x, y, z) + + def main(): + z0 = ewise_fused(a, b, c) + """ + + def __init__(self, mod: IRModule) -> None: + super().__init__() + self.mod_ = mod + + def transform(self) -> IRModule: + for global_var, func in self.mod_.functions.items(): + if isinstance(func, Function): + func = self.visit_expr(func) + self.builder_.add_func(func, global_var.name_hint) + + return self.builder_.get() + + def visit_call_(self, call: Call) -> Call: # pylint: disable=arguments-differ + call = self.visit_expr_post_order(call) + add_op = Op.get("relax.add") + multiply_op = Op.get("relax.multiply") + ewise_fma_op = Op.get("relax.ewise_fma") + + if call.op == add_op: + value = self.lookup_binding(call.args[0]) + if isinstance(value, Call) and value.op == multiply_op: + mul = value + # construct a subgraph + x = Var("x", mul.args[0].shape_, mul.args[0]._checked_type_) + y = Var("y", mul.args[1].shape_, mul.args[1]._checked_type_) + z = Var("z", call.args[1].shape_, call.args[1]._checked_type_) + body = Call(ewise_fma_op, [x, y, z]) + + func_name = "ewise_fma_fused" + # TODO: Possibly fill in the return shape + func = Function([x, y, z], body, call.args[1]._checked_type_, RuntimeDepShape()) + ewise_fma_fused = func.with_attr("global_symbol", func_name) + normalized = self.builder_.normalize(ewise_fma_fused) + global_var = self.builder_.add_func(normalized, "ewise_fma_fused") + + # construct a call to the subgraph + fma_call = Call(global_var, [mul.args[0], mul.args[1], call.args[1]], None, None) + + return fma_call + + return call + + +@module_pass(opt_level=2, name="ewise_fuse_fma_rewriter") +class EwiseFuseFMA: + """The wrapper for the EwiseFuseFMA pass.""" + + def transform_module(self, mod, ctx): + return EwiseFuseFMAMutator(mod).transform() diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py new file mode 100644 index 0000000000..cc37eb8cb3 --- /dev/null +++ b/python/tvm/relax/transform/transform.py @@ -0,0 +1,650 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Relax transformation passes.""" +import functools +import inspect +import types +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import tvm.ir +from tvm.runtime import NDArray +from . import _ffi_api + + +@tvm._ffi.register_object("relax.FunctionPass") +class FunctionPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.Function in a module. A function + pass class should be created through `function_pass`. 
+ """ + + +@tvm._ffi.register_object("relax.DataflowBlockPass") +class DataflowBlockPass(tvm.ir.transform.Pass): + """A pass that works on each tvm.relax.DataflowBlock in a module.""" + + +def FailTestRewrite() -> tvm.ir.transform.Pass: + """Incorrectly transform the dataflow structure as fail testcases. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.FailTestRewrite() + + +def RewriteFMA() -> tvm.ir.transform.Pass: + """Perform fused multiply add rewriting in dataflow blocks. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.RewriteFMA() + + +def FuseFMA() -> tvm.ir.transform.Pass: + """Perform fused multiply add rewriting, generate a subgraph(sub function), + and call into the sub function in the main function. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.FuseFMA() + + +def LambdaLift(): + """ + Lift local functions into global. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.LambdaLift() + + +def ToNonDataflow() -> tvm.ir.transform.Pass: + """Transform all dataflow structure to non-dataflow version. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ToNonDataflow() + + +def CallTIRRewrite() -> tvm.ir.transform.Pass: + """Perform explicit tensor allocation for call_tir. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CallTIRRewrite() + + +def VMMemoryLower() -> tvm.ir.transform.Pass: + """Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMMemoryLower() + + +def VMShapeLower() -> tvm.ir.transform.Pass: + """Lower the shape expressions in relax to VM shape heap manipulations and generate related + TIR functions to do shape calculations. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMShapeLower() + + +def Normalize() -> tvm.ir.transform.Pass: + """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting + and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.Normalize() + + +def CanonicalizeBindings() -> tvm.ir.transform.Pass: + """ + Canonicalizes variable definitions + (e.g., if there is y = x and z = y, it replaces uses of y and z with x). + + Best combined with constant folding and the elimination of unused definitions. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeBindings() + + +def ResolveGlobals() -> tvm.ir.transform.Pass: + """Resolve global variables using string equality. This ensures all GlobalVars in the IR refer + to the correct GlobalVar of the input IRModule. An error is reported if any GlobalVar cannot be + resolved. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ResolveGlobals() + + +def BindParams( + func_name: str, + params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]], +) -> tvm.ir.transform.Pass: + """Bind params of function of the module to constant tensors. + + Parameters + ---------- + + func_name: str + The function name to be bound + + params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]] + The map from param name to constant tensors. 
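+
+    A usage sketch (the parameter name "w" and its shape are illustrative):
+
+    .. code-block:: python
+
+        params = {"w": np.zeros((4, 4), dtype="float32")}
+        mod = relax.transform.BindParams("main", params)(mod)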
+ + Returns + ------- + ret: tvm.ir.transform.Pass + """ + tvm_params = {} + for k, v in params.items(): + if isinstance(v, np.ndarray): + v = tvm.nd.array(v) + assert isinstance( + v, tvm.runtime.NDArray + ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}" + tvm_params[k] = v + + return _ffi_api.BindParams(func_name, tvm_params) + + +def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass: + """Remove unused relax/prim functions without external linkage in a IRModule. + + Parameters + ---------- + entry_functions: Optional[List[str]] + The set of entry functions to start from. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass to remove unused functions. + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.RemoveUnusedFunctions(entry_functions) + + +def RunCodegen( + target_codegens: Optional[List[str]] = None, entry_functions: Optional[List[str]] = None +) -> tvm.ir.transform.Pass: + """Produce the runtime::Module with an annotated codegen and global symbol. + + Parameters + ---------- + target_codegens: Optional[List[str]] + List of target codegens. If empty, perform all codegens by default. + entry_functions: Optional[List[str]] + The set of entry functions to start from. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass to remove unused functions. + """ + if entry_functions is None: + entry_functions = ["main"] + return _ffi_api.RunCodegen(target_codegens, entry_functions) + + +def FoldConstant() -> tvm.ir.transform.Pass: + """Fold constant expressions. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.FoldConstant() + + +def AnnotateTIROpPattern() -> tvm.ir.transform.Pass: + """Annotate Op Pattern Kind for TIR functions + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AnnotateTIROpPattern() + + +def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: + """This pass groups bindings in a dataflow block of Relax functions and generate a new grouped + Relax function for each group, according to the fusion algorithm described in the pass + implementation. By grouping bindings into new Relax functions, we substitute the bindings in + the function being manipulated into function calls to the new grouped function. + + A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + + Parameters + ---------- + fuse_opt_level : int + The level of fuse optimization. -1 indicates that the level will be + inferred from pass context. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for operator fusion. + """ + return _ffi_api.FuseOps(fuse_opt_level) + + +def FuseTIR() -> tvm.ir.transform.Pass: + """Fuse primitive relax function into a larger TIR function if possible + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for tir fusion. + """ + return _ffi_api.FuseTIR() + + +def MetaScheduleApplyDatabase( + work_dir: Optional[str] = None, +) -> tvm.ir.transform.Pass: + """Apply the best schedule from tuning database. + work_dir : Optional[str] + work directory to deduce default database if database is not provided + (it will be ignored when an user passes database) + Returns + ------- + ret : tvm.transform.Pass + The registered pass + """ + return _ffi_api.MetaScheduleApplyDatabase(work_dir) + + +def MetaScheduleTuneTIR( + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune TIR with MetaSchedule. 
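+
+    A workflow sketch (the work directory and trial budget are illustrative);
+    MetaScheduleApplyDatabase above can then pick up the tuned schedules from
+    the same work directory:
+
+    .. code-block:: python
+
+        mod = relax.transform.MetaScheduleTuneTIR("./tune_tmp", 64)(mod)
+        mod = relax.transform.MetaScheduleApplyDatabase("./tune_tmp")(mod)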
+ Parameters + ---------- + work_dir: str + work directory + max_trials_gloabl: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneTIR(work_dir, max_trials_global) + + +def MetaScheduleTuneIRMod( + params: Dict[str, NDArray], + work_dir: str, + max_trials_global: int, +) -> tvm.ir.transform.Pass: + """Tune Relax IRModule with MetaSchedule. + Parameters + ---------- + params: Dict[str, NDArray] + model params + work_dir: str + work directory + max_trials_gloabl: int + maximum number of total trials allowed for tuning + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) + + +def AssignPoolInfo() -> tvm.ir.transform.Pass: + """Assign PoolInfo objects to Relax and TIR allocates depending on the function target + + This pass would assign default PoolInfo objects to allocates that are not otherwise + annotated, depending on pool info supplied for each target. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for assigning pool infos. + """ + return _ffi_api.AssignPoolInfo() + + +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass.""" + + class PyFunctionPass(FunctionPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + + self.__init_handle_by_constructor__(_ffi_api.MakeFunctionPass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + +def function_pass( + pass_func=None, + opt_level=None, + name=None, + required=None, + traceable=False, +) -> Union[Callable, FunctionPass]: + """Decorate a function pass. + + This function returns a callback when pass_func + is provided. Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] + The transformation function or class. + + opt_level : int + The optimization level of this function pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the function pass is dependent on. + + traceable: Boolean + Boolean variable whether the function pass is traceable + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a function pass class. + + .. 
code-block:: python + + @relax.transform.function_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + @R.function + def f1(x: Tensor[(m, n), "float32"]): + return x + + @tvm.script.ir_module + class InputMod: + @R.function + def f2(x: Tensor[(m, n), "float32"]): + gv0 = relax.add(x, x) + return gv0 + # fpass is now a special pass that replaces every + # function to f1 + fpass = TestReplaceFunc(f1) + # now every function in InputMod is replaced by f1 + updated_mod = fpass(InputMod) + + + The following code creates a function pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.function_pass(opt_level=2) + def transform(func, mod, ctx): + # my transformations here. + return func + + function_pass = transform + assert isinstance(function_pass, relax.transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = function_pass(m) + # Now transform should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the function pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + "list/tuple.") + + def create_function_pass(pass_arg): + """Internal function that creates a function pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Function pass") + return _ffi_api.MakeFunctionPass(pass_arg, info) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass + + +def _wrap_class_dataflowblock_pass(pass_cls, pass_info): + """Wrap a python class as dataflowblock pass""" + + class PyDataflowBlockPass(DataflowBlockPass): + """Internal wrapper class to create a class instance.""" + + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation failed. + self.handle = None + inst = pass_cls(*args, **kwargs) + + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_dataflowblock(func, mod, ctx) + + self.__init_handle_by_constructor__( + _ffi_api.MakeDataflowBlockPass, _pass_func, pass_info + ) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyDataflowBlockPass.__init__, pass_cls.__init__) + PyDataflowBlockPass.__name__ = pass_cls.__name__ + PyDataflowBlockPass.__doc__ = pass_cls.__doc__ + PyDataflowBlockPass.__module__ = pass_cls.__module__ + return PyDataflowBlockPass + + +def dataflowblock_pass( + pass_func=None, opt_level=None, name=None, required=None, traceable=False +) -> Union[Callable, DataflowBlockPass]: + """Decorate a dataflowblock pass. + + This function returns a callback when pass_func + is provided. Otherwise, it returns the created dataflowblock pass using the + given optimization function. 
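+
+    As with other tvm.ir.transform passes, whether the pass actually runs is
+    gated by the opt_level of the enclosing PassContext (a sketch; the pass
+    instance and module names are illustrative):
+
+    .. code-block:: python
+
+        with tvm.transform.PassContext(opt_level=3):
+            updated_mod = my_block_pass(input_mod)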
+ + Parameters + ---------- + pass_func : Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]] + The transformation function or class. + + opt_level : int + The optimization level of this dataflowblock pass. + + name : Optional[str] + The name of the dataflowblock pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the dataflowblock pass is dependent on. + + traceable: Boolean + Boolean variable whether the dataflowblock pass is traceable + + Returns + ------- + create_dataflowblock_pass : Union[Callable, DataflowBlockPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new DataflowBlockPass will be returned when we decorate a pass function. + A new DataflowBlockPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a dataflowblock pass class. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=1) + class TestReplaceBinding: + # Simple test function to replace the first VarBinding to another. + + def __init__(self): + # create a new VarBinding + m, n = tir.Var("m", "int64"), tir.Var("n", "int64") + type_anno = relax.DynTensorType(2, "float32") + lv0 = relax.Var("lv1", [m, n], type_anno) + val = relax.const(np.random.rand(24, 56)) + self.new_binding = relax.VarBinding(lv0, val) + + def transform_dataflowblock(self, block, mod, ctx): + # just for demo purposes + # Replace the first binding in the DataflowBlock + new_bindings = [self.new_binding, block.bindings[1]] + new_block = relax.expr.DataflowBlock(new_bindings, block.span) + return new_block + + @tvm.script.ir_module + class InputMod: + @R.function + def f1(x: Tensor[(m, n), "float32"]): + with relax.dataflow(): + lv0 = relax.multiply(x, x) + gv0 = relax.add(x, x) + relax.output(gv0) + return gv0 + # block_pass is now a special pass that replaces every + # first binding to the constant value binding + block_pass = TestReplaceBinding() + # now every first binding in DataflowBlock of InputMod + # is replaced by new_binding + updated_mod = block_pass(InputMod) + + + The following code creates a dataflowblock pass by decorating + a user defined transform function. + + .. code-block:: python + + @relax.transform.dataflowblock_pass(opt_level=2) + def transform(block, mod, ctx): + # my transformations here. + return block + + block_pass = transform + assert isinstance(block_pass, relax.transform.DataflowBlockPass) + assert block_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the follwoing: + updated_mod = block_pass(m) + # Now transform should have been applied to every DataflowBlock in + # the provided module m. And the updated module will be returned. 
+ """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the dataflowblock pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + "list/tuple.") + + def create_dataflowblock_pass(pass_arg): + """Internal function that creates a dataflowblock pass""" + fname = name if name else pass_arg.__name__ + info = tvm.transform.PassInfo(opt_level, fname, required, traceable) + if inspect.isclass(pass_arg): + return _wrap_class_dataflowblock_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for DataflowBlock pass") + return _ffi_api.MakeDataflowBlockPass(pass_arg, info) + + if pass_func: + return create_dataflowblock_pass(pass_func) + return create_dataflowblock_pass diff --git a/python/tvm/relax/transform/tuning_api/__init__.py b/python/tvm/relax/transform/tuning_api/__init__.py new file mode 100644 index 0000000000..6c39d5c535 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax Tunign Pass API""" + +from .primitives import * +from .default_functions import * +from .database import * diff --git a/python/tvm/relax/transform/tuning_api/_ffi_api.py b/python/tvm/relax/transform/tuning_api/_ffi_api.py new file mode 100644 index 0000000000..f31522d025 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +"""FFI APIs for relax.tuning_api""" +import tvm._ffi + +tvm._ffi._init_api("relax.tuning_api", __name__) diff --git a/python/tvm/relax/transform/tuning_api/database.py b/python/tvm/relax/transform/tuning_api/database.py new file mode 100644 index 0000000000..9477e142ba --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/database.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import List, Optional +import logging + +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.meta_schedule.utils import _json_de_tvm +from tvm.meta_schedule.database import Workload +from tvm.tir.schedule.trace import JSON_TYPE +from tvm.target import Target +from tvm._ffi import register_object +from .primitives import Trace +from . import _ffi_api + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + + +@register_object("relax.tuning_api.TuningRecord") +class TuningRecord(Object): + """The class of tuning records. + + Parameters + ---------- + trace : tvm.relax.transform.tuning_api.Trace + The trace of the tuning record. + run_secs : Optional[List[float]] + The run-time of the tuning record. + """ + + trace: Trace + run_secs: Optional[List[float]] + + def __init__( # type: ignore # pylint: disable=too-many-arguments + self, + trace: Trace, + run_secs: Optional[List[float]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member + trace, + run_secs, + ) + + def as_json(self, include_irmod: bool = False) -> JSON_TYPE: + """Export the tuning record to a JSON string. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self, include_irmod)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "TuningRecord": + """Create a tuning record from a json object. + + Parameters + ---------- + json_obj : JSON_TYPE + The json object to parse. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.TuningRecordFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.Database") +class Database(Object): + """The abstract database interface.""" + + def has_workload(self, mod: IRModule) -> bool: + """Check if the database has the given workload. + Parameters + ---------- + mod : IRModule + The IRModule to be searched for. + + Returns + ------- + result : bool + Whether the given workload is committed. 
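+
+        A usage sketch (``db`` is assumed to be a Database instance, e.g. the
+        JSONDatabase defined below):
+
+        .. code-block:: python
+
+            if not db.has_workload(mod):
+                workload = db.commit_workload(mod)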
+ """ + return _ffi_api.DatabaseHasWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def has_measurement_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a measurement record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the measurement record. + """ + return _ffi_api.DatabaseHasMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def has_tuning_record(self, workload: Workload, target: Target) -> bool: + """Check if the database has a tuning record for the given workload and target pair. + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + result : bool + Whether the given workload and target pair is committed for the tuning record. + """ + return _ffi_api.DatabaseHasTuningRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def commit_workload(self, mod: IRModule) -> Workload: + """Commit a workload to the database if missing. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for or added. + + Returns + ------- + workload : Workload + The workload corresponding to the given IRModule. + """ + return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def commit_measurement_record( + self, workload: Workload, target: Target, run_secs: List[float] + ) -> None: + """Commit a measurement record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + run_secs : Optional[List[float]] + The measurement record to add. + """ + _ffi_api.DatabaseCommitMeasurementRecord(self, workload, target, run_secs) # type: ignore # pylint: disable=no-member + + def commit_tuning_record( + self, workload: Workload, target: Target, record: TuningRecord + ) -> None: + """Commit a tuning record to the database. + A pair of workload and target will be used as a key. + + Parameters + ---------- + workload: Workload + The workload to be searched for. + target: Target + The target to be searched for. + record : TuningRecord + The tuning record to add. + """ + _ffi_api.DatabaseCommitTuningRecord(self, workload, target, record) # type: ignore # pylint: disable=no-member + + def get_measurement_record(self, workload: Workload, target: Target) -> Optional[List[float]]: + """Get the measurement record of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + + Returns + ------- + measurement_record : Optional[List[float]] + Measurement record if exists. + """ + return _ffi_api.DatabaseGetMeasurementRecord(self, workload, target) # type: ignore # pylint: disable=no-member + + def get_top_k(self, workload: Workload, target: Target, top_k: int) -> List[TuningRecord]: + """Get the top K tuning records of given workload and target from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + target: Target + The target to be searched for. + top_k : int + The number of top records to get. 
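+
+        A usage sketch (``db`` is assumed to be a JSONDatabase defined below):
+
+        .. code-block:: python
+
+            records = db.get_top_k(db.commit_workload(mod), target, top_k=3)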
+ + Returns + ------- + top_k_records : List[TuningRecord] + The top K records. + """ + return _ffi_api.DatabaseGetTopK(self, workload, target, top_k) # type: ignore # pylint: disable=no-member + + +@register_object("relax.tuning_api.JSONDatabase") +class JSONDatabase(Database): + """The class of JSON database. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + Manages pairs of + path_measurement_record : str + The path to the path_measurement_record table. + Manages pairs of + """ + + path_workload: str + path_tuning_record: str + path_measurement_record: str + + def __init__( + self, + path_workload: str, + path_tuning_record: str, + path_measurement_record: str, + allow_missing: bool = True, + ) -> None: + """Constructor. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + path_measurement_record : str + The path to the path_measurement_record table. + allow_missing : bool + Whether to create new file when the given path is not found. + """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member + path_workload, + path_tuning_record, + path_measurement_record, + allow_missing, + ) diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py new file mode 100644 index 0000000000..30b2d69b1d --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -0,0 +1,304 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API default functions""" +from typing import Dict, List, Optional +import sys +import itertools +import logging +import numpy as np + +import tvm +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext, Pass +from tvm import meta_schedule +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + LocalRunner, + RunnerInput, +) +from tvm._ffi.registry import register_func +from .primitives import Knob, Trace + +logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name + +# Default transform func that returns original IRModule. +@tvm.register_func("relax.tuning_api.Choice.default_transform_func") +def default_transform_func(mod): + return mod + + +# Default constraint func that always returns true. 
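+# A Choice() constructed with no arguments falls back to the two defaults
+# registered in this file, i.e. an identity transformation that is always
+# applicable (see Choice in tuning_api/primitives.py), for example:
+#
+#   noapply = Choice()  # default_transform_func + default_constr_func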
+@tvm.register_func("relax.tuning_api.Choice.default_constr_func")
+def default_constr_func(mod: IRModule) -> bool:  # pylint: disable=unused-argument
+    return True
+
+
+@register_func("relax.tuning_api.default_generate_candidate")
+def default_generate_candidate(
+    knobs: List[Knob], trace: Trace, eval_passes: Optional[List[Pass]] = None
+) -> List[Trace]:
+    """
+    Default function to generate the search space for a given trace by using registered choices.
+    This function simply expands the candidate space as long as each knob's constraint is satisfied.
+    To reduce the search space, a developer may expand each choice with a smarter search method
+    (e.g., genetic search, multi-armed bandit).
+    Note that each pass generates candidates without worrying about the interaction with other passes,
+    i.e., it only uses its incoming trace/IRModule and Choices for candidate generation.
+    This will help alleviate the complexity of joint-optimization significantly
+    - consideration of the interaction between optimizations is known to be extremely difficult.
+
+    Parameters
+    ----------
+    knobs : List[Knob]
+        List of Knobs to consider to generate candidate for input trace.
+    trace: Trace
+        Input trace.
+    eval_passes: Optional[List[Pass]]
+        List of passes to consider to evaluate each candidate.
+        This will enable joint-optimization.
+
+    Return
+    ----------
+    candidates: List[Trace]
+        List of candidate traces
+    """
+
+    candidates = [trace]
+    # Iterate over every knob.
+    for knob in knobs:
+        num = len(candidates)
+        for _ in range(num):
+            cur_trace = candidates.pop(0)
+            for decision in knob.choices.keys():
+                choice = knob.choices[decision]
+                # Generate a new candidate when the choice's constraint is satisfied.
+                if choice.check_constr(cur_trace.out_mod):
+                    new_trace = cur_trace.deepcopy()
+                    new_trace.add(knob, decision)
+                    candidates.append(new_trace)
+
+    # Expand candidates by using eval passes if provided. This will enable joint-optimization.
+    if eval_passes:
+        candidates = default_consider_eval_passes(candidates, eval_passes)
+    return candidates
+
+
+@register_func("relax.tuning_api.default_consider_eval_passes")
+def default_consider_eval_passes(
+    init_candidates: List[Trace], eval_passes: Optional[List[Pass]] = None
+) -> List[Trace]:
+    """
+    Default function to update traces with eval passes.
+    It visits each eval_pass in DFS order via transform.Sequential() and
+    returns the best possible candidate trace for each candidate.
+
+    Parameters
+    ----------
+    init_candidates: List[Trace]
+        Initial candidates
+    eval_passes: Optional[List[Pass]]
+        List of passes to consider to evaluate each candidate.
+        This will enable joint-optimization.
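+
+    A sizing sketch: with two independent knobs of two passing choices each,
+    default_generate_candidate above yields up to four traces, and each trace
+    is then refined through the eval passes (``knob_a``/``knob_b`` are
+    illustrative):
+
+    .. code-block:: python
+
+        candidates = default_generate_candidate([knob_a, knob_b], Trace(mod))
+        assert len(candidates) <= 4  # product of passing choices per knob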
+ Return + ---------- + candidates: List[Trace] + List of candidate traces + """ + if not eval_passes: + return init_candidates + + eval_passes = list(eval_passes) if not isinstance(eval_passes, list) else eval_passes + ctx = PassContext.current() + candidates = [] + + for trace in init_candidates: + ctx.push_trace(trace) + tvm.transform.Sequential(eval_passes)(trace.out_mod) + new_trace = ctx.pop_trace() + # A new trace contains the best decisions in eval_passes + candidates.append(new_trace) + + return candidates + + +@register_func("relax.tuning_api.default_evaluate") +def default_evaluate( + candidates: List[Trace], + target_str: str, + params: Optional[Dict[str, np.ndarray]] = None, + builder: Optional[meta_schedule.builder.Builder] = None, + runner: Optional[meta_schedule.runner.Runner] = None, +) -> None: + """ + Default function to evaluate a set of candidate traces by using MetaSchedule builder/runner. + + Parameters + ---------- + candidates: List[Trace] + List of traces to evaluate. + target_str: str, + Compilation target (e.g., llvm, cuda). + params: Optional[Dict[str, np.ndarray]] + Params to bind. + builder: Optional[meta_schedule.builder.Builder] + builder function. If not provided, default local builder will be used. + runner: Optional[meta_schedule.runner.Runner] + runner function. If not provided, default local runner will be used. + """ + + ctx = PassContext.current() + target = tvm.target.Target(target_str) + database = PassContext.current().get_tuning_api_database() + # Setup default local builder if not provided + if builder is None: + + def relax_build( + mod: IRModule, + target: tvm.target.Target, + params: Optional[Dict[str, np.ndarray]], + ): + if params: + mod = tvm.relax.transform.BindParams("main", params)(mod) + relax_exec = tvm.relax.vm.build(mod, target) + return relax_exec.mod + + builder = LocalBuilder(f_build=relax_build) + + # Setup default local runner if not provided + if runner is None: + + def relax_eval_func(rt_mod, device, evaluator_config, repeated_args): + relax_exec = tvm.relax.vm.Executable(rt_mod) + relax_vm = tvm.relax.VirtualMachine(exec=relax_exec, device=device) + + evaluator = relax_vm.module.time_evaluator( + func_name="main", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + + return costs + + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=3, repeat=5, min_repeat_ms=100, enable_cpu_cache_flush=False + ), + f_run_evaluator=relax_eval_func, + ) + + # set up clean up function + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + assert f_clean_build + + # Keep track of number of evaluations (mostly for the debugging purpose) + num_evals = 0 + # Evaluation + for candidate in candidates: + # If this candidate is already evaluated, skip the measurement + if candidate.perf != -1: + continue + + # Evaluate candidates + num_evals += 1 + mod = candidate.out_mod + workload = database.commit_workload(mod) + + # If this workload and target pair has measured before, fetch its data. + if database.has_measurement_record(workload, target): + run_secs = database.get_measurement_record(workload, target) + # Otherwise, measure it. 
+ else: + # Build candidate + (builder_result,) = builder.build([BuilderInput(mod, target, params)]) + + if builder_result.artifact_path is None: + # Build error + # Assign the worst performance and move on to the next candidate. + logger.warning(builder_result.error_msg) + run_secs = [1e100] + else: + # If build passes, set up runner input and measure the performance. + args_info = [ + TensorInfo(shape=[int(i) for i in p.shape], dtype=p.checked_type.dtype) + for p in mod["main"].params + ] # convert list[Var] to list[TensorInfo] + runner_input = RunnerInput( + builder_result.artifact_path, target_str, args_info=args_info + ) + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + run_secs = runner_result.run_secs + # Runtime error + # Assign the worst performance and move on to the next candidate. + if runner_result.error_msg is not None: + logger.warning(runner_result.error_msg) + run_secs = [1e100] + + database.commit_measurement_record(workload, target, run_secs) + + # Clean up the artifact + f_clean_build(builder_result.artifact_path) + + # For valid measurments, compute the average and update the trace performance. + perfs = [] + for result in run_secs: + if isinstance(result, tvm.tir.FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + perfs.append(result) + + # Store the evaluation result + candidate.set_perf(np.mean(perfs)) + + ctx.inc_num_evals(num_evals) + + +def select_best_candidate(candidates: List[Trace]) -> Trace: + """ + Select the best trace. + + Parameters + ---------- + candidates: List[Trace] + Candidate traces + + Return + ---------- + best_trace: Trace + Trace with the best performance + """ + best_perf, best_trace = sys.maxsize, None + for candidate in candidates: + avg = candidate.perf + # Select best one + if best_perf > avg: + best_perf = avg + best_trace = candidate + return best_trace diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py new file mode 100644 index 0000000000..23b2101545 --- /dev/null +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax Tuning Pass API primitives""" + +from typing import Callable, Union, Dict, List, Optional +import logging +import tvm +from tvm.runtime import Object +from tvm.ir.module import IRModule +from tvm.relax import Expr +from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm +from tvm._ffi import register_object +from . 
import _ffi_api
+
+logger = logging.getLogger("TuningAPI")  # pylint: disable=invalid-name
+
+
+@register_object("relax.tuning_api.Choice")
+class Choice(Object):
+    """
+    A TVM object Choice that maintains a set of transformation and constraint function keys.
+    Corresponding functions should be registered as PackedFunc with these keys.
+    The transformation function will be applied when the constraint function returns true.
+    Parameters
+    ----------
+    transform_func_key : Optional[str]
+        Key for transformation function.
+    transform_func_args : Optional[List]
+        Arguments for transformation function.
+    constr_func_key : Optional[str]
+        Key for constraint function.
+    constr_func_args : Optional[List]
+        Arguments for constraint function.
+
+    Examples
+    --------
+    The following code block defines a Choice.
+
+    .. code-block:: python
+        @tvm.register_func("relax.tuning_api.test.transform_func")
+        def apply(mod):
+            return relax.transform.FoldConstant()(mod)
+        @tvm.register_func("relax.tuning_api.test.constr_func")
+        def constr(mod):
+            return len(mod.functions) == 3
+        # Define a choice to apply constant folding only when the IRModule has three functions.
+        choice = Choice(
+            transform_func_key = "relax.tuning_api.test.transform_func",
+            constr_func_key = "relax.tuning_api.test.constr_func"
+        )
+    """
+
+    def __init__(
+        self,
+        transform_func_key: Optional[str] = None,
+        transform_func_args: Optional[List] = None,
+        constr_func_key: Optional[str] = None,
+        constr_func_args: Optional[List] = None,
+    ):
+        """Constructor
+        Parameters
+        ----------
+        transform_func_key : Optional[str]
+            Key for transformation function.
+
+        transform_func_args : Optional[List]
+            Arguments for transformation function.
+
+        constr_func_key : Optional[str]
+            Key for constraint function.
+
+        constr_func_args : Optional[List]
+            Arguments for constraint function.
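+
+        A sketch of the default behavior when no arguments are given:
+
+        .. code-block:: python
+
+            choice = Choice()  # identity transformation, always applicable
+            assert choice.check_constr(mod)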
+ """ + + if transform_func_key is None: + transform_func_key = "relax.tuning_api.Choice.default_transform_func" + + if transform_func_args is None: + transform_func_args = [] + + if constr_func_key is None: + constr_func_key = "relax.tuning_api.Choice.default_constr_func" + + if constr_func_args is None: + constr_func_args = [] + + self.__init_handle_by_constructor__( + _ffi_api.Choice, + transform_func_key, + transform_func_args, + constr_func_key, + constr_func_args, # type: ignore # pylint: disable=no-member + ) + + def get_transform_func(self) -> Callable: + """Getter for transform_func + Returns + ------- + ret: Callable + registered transformation function + """ + return _ffi_api.ChoiceGetTransformFunc(self) + + def get_constr_func(self) -> Callable: + """Getter for constr_func + Returns + ------- + ret: Callable + registered constraint function + """ + return _ffi_api.ChoiceGetConstrFunc(self) + + def apply_transform_func(self, mod: IRModule) -> IRModule: + """Perform transform_func with its arguments + Returns + ------- + ret: IRModule + Transformed IRModule + """ + return _ffi_api.ChoiceApplyTransformFunc(self, mod) + + def check_constr(self, mod: IRModule) -> bool: + """Perform constr_func with its arguments + Returns + ------- + ret: bool + Returns whether the IRModule satisfies the constraint or not + """ + return _ffi_api.ChoiceCheckConstr(self, mod) + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.ChoiceAsJSON(self) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Choice": + """Create Choice from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Choice serialized with JSON + + Return + ---------- + choice: Choice + Deserialized choice + """ + return _ffi_api.ChoiceFromJSON(json_obj) + + def deepcopy(self): + return Choice.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Knob") +class Knob(Object): + """ + A TVM object Knob that maintains a set of valid Choices. + By using Knobs, a tuning pass can generate candidates and define the search space. + Parameters + ---------- + name : str + Name of the knob. + + choices: Union[List[Choice], Dict[str, Choice]] + A list of valid choices + + Examples + -------- + The following code block defines a Knob. + + .. 
code-block:: python + @tvm.register_func("relax.tuning_api.test.transform_func") + def apply(mod): + return relax.tuning_api.FoldConstant()(mod) + choices = {"apply": Choice("relax.tuning_api.test.transform_func"), "noapply": Choice()} + # A knob manages a set of its valid choices + knob = Knob("MockTuningKnob", choices) + """ + + def __init__(self, name: str, choices: Union[List[Choice], Dict[str, Choice]]): + """Constructor.""" + if isinstance(choices, list): + choices = {str(idx): val for idx, val in enumerate(choices)} + + self.__init_handle_by_constructor__( + _ffi_api.Knob, name, choices # type: ignore # pylint: disable=no-member + ) + + def verify(self, decision: Union[str, int]) -> bool: + """Verify if the decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobIsValidDecision(self, decision) + + def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule: + """Get choice if a decision is valid.""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.KnobApply(self, mod, decision) + + def as_json(self) -> JSON_TYPE: + """Serialize the trace as a JSON-style object + Returns + ------- + json: JSON_TYPE + The JSON-style object + """ + return _ffi_api.KnobAsJSON(self) + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Knob": + """Create Knob from JSON obj + + Parameters + ---------- + json_obj: JSON_TYPE + Knob serialized with JSON + + Return + ---------- + knob: Knob + Deserialized knob + """ + return _ffi_api.KnobFromJSON(json_obj) + + def __str__(self) -> str: + msg = f"{self.name} (# of choices: {len(self.choices)})\n" + for name, choice in self.choices.items(): + msg += f" - {name}: {choice}\n" + return msg + + def deepcopy(self): + return Knob.from_json(self.as_json()) + + +@register_object("relax.tuning_api.Trace") +class Trace(Object): + """ + A TVM object Trace logs the history of transformations (decisions). + Parameters + ---------- + in_mod : IRModule + Input IRModule. + knobs: Optional[List[Knob]] + A list of knobs applied in the trace. + decisions: Optional[List[Union[str, int]]] + A list of decisions made for each knob + + Examples + -------- + The following code block defines a Trace. + + .. code-block:: python + + trace = Trace(mod, [knob1, knob2, knob3], ["c1", "c0", "c3"]) + assert trace.size == 3 # Length of history. + # 'out' contains IRModule that applies transformations in the trace. + out: IRModule = trace.add(knob4, "c2") + assert trace.size == 4 # Length of history. + trace.set_perf(0.03) # Set the performance number of the trace. 
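+            # Serialization round-trip and deep copy (illustrative):
+            obj = trace.as_json()
+            restored = Trace.from_json(obj)
+            copied = trace.deepcopy()  # copies in_mod/out_mod, knobs, and decisions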
+ """ + + def __init__( + self, + in_mod: IRModule, + knobs: Optional[List[Knob]] = None, + decisions: Optional[List[Union[str, int]]] = None, + ): + """Constructor.""" + knobs = knobs if knobs else list() + decisions = ( + [str(v) if isinstance(v, int) else v for v in decisions] if decisions else list() + ) + self.__init_handle_by_constructor__( + _ffi_api.Trace, in_mod, knobs, decisions # type: ignore # pylint: disable=no-member + ) + + def verify(self) -> bool: + """Verify if current history is valid.""" + return _ffi_api.TraceVerify() + + def add(self, knob: Knob, decision: Union[str, int]) -> IRModule: + """Add & Apply new decision (with knob).""" + if isinstance(decision, int): + decision = str(decision) + return _ffi_api.TraceAdd(self, knob, decision) + + def set_perf(self, perf: float) -> None: + """Set performance number for the trace.""" + return _ffi_api.TraceSetPerf(self, perf) + + def set_out_mod(self, mod: IRModule) -> None: + """Set out_mod for the trace.""" + return _ffi_api.TraceSetOutMod(self, mod) + + def as_json(self, include_irmod: bool = True) -> JSON_TYPE: + """Serialize the trace as a JSON-style object. + Parameters + ---------- + include_irmod: bool + Decides whether to serialize in_mod as well. + + Returns + ------- + json: JSON_TYPE + The JSON-style object. + """ + obj = _ffi_api.TraceAsJSON(self, include_irmod) + return _json_from_tvm(obj) + + @staticmethod + def from_json(json_obj: JSON_TYPE) -> "Trace": + """Create Trace from JSON obj. + + Parameters + ---------- + json_obj: JSON_TYPE + Trace serialized with JSON. + + Return + ---------- + trace: Trace + Deserialized trace. + """ + return _ffi_api.TraceFromJSON(json_obj) + + def __str__(self) -> str: + n = len(self.knobs) + msg = f"Trace length: {n}\n" + for idx in range(n): + msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n" + return msg + + def deepcopy(self) -> "Trace": + new_in_mod = deepcopy_irmodule(self.in_mod) + new_knobs = [knob.deepcopy() for knob in self.knobs] + new_decisions = [str(decision) for decision in self.decisions] + new_trace = Trace(new_in_mod, new_knobs, new_decisions) + new_out_mod = deepcopy_irmodule(self.out_mod) + new_trace.set_out_mod(new_out_mod) + return new_trace + + +def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: + """ + Getter for a trace wrapper. + + Parameters + ---------- + in_: Union[Trace, IRModule, Expr] + Input entity + Return + ---------- + wrapped: Trace + Traced entity + """ + if isinstance(in_, Trace): + return in_ + if isinstance(in_, IRModule): + return Trace(in_) + if isinstance(in_, Expr): + return Trace(tvm.IRModule.from_expr(in_)) + + raise Exception(f"Invalid input type for trace: {type(in_)}") + + +@tvm.register_func("relax.tuning_api.deepcopy_irmodule") +def deepcopy_irmodule(mod: IRModule) -> IRModule: + """ + Deepcopy for an IRModule. + Parameters + ---------- + mod: IRModule + input IRModule + Return + ---------- + copied_mod: IRModule + deep-copied IRModule + """ + func_save_json = tvm.get_global_func("node.SaveJSON") + func_load_json = tvm.get_global_func("node.LoadJSON") + new_mod = None + # Handle external modules separately if exist + # TODO(tvm-team): + # Serialization of IRModule with external mods is tricky. + # (1) External mod is runtime module. + # (2) Currently, `export_library` does not support serialization of + # runtime module without the host module + # Therefore, we simply pass around the compiled external modules without copy for now. + # Revisit later when we have a better solution. 
+ if mod.attrs and "external_mods" in mod.attrs: + tmp_mod = mod.without_attr("external_mods") + new_mod = func_load_json(func_save_json(tmp_mod)) + new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"]) + else: + new_mod = func_load_json(func_save_json(mod)) + + return new_mod diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py new file mode 100644 index 0000000000..9bcb95c215 --- /dev/null +++ b/python/tvm/relax/ty.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The type nodes of the Relax language.""" +import tvm._ffi +from tvm.ir import Type, TensorType, TupleType, FuncType, Span + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.ShapeType") +class ShapeType(Type): + """The type of shape in Relax.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeType, span) + + +@tvm._ffi.register_object("relax.ObjectType") +class ObjectType(Type): + """A type that corresponds to tvm::runtime::Object, is base of all possible object + values in TVM.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) + + +@tvm._ffi.register_object("relax.DynTensorType") +class DynTensorType(Type): + """A dynamic tensor type in Relax. + + This is the type assigned to tensors with a known dtype and unknown shape. + + Parameters + ---------- + ndim : Optional[int] + The ndim of the Tensor + + dtype : Optional[str] + The content data type. + """ + + def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DynTensorType, ndim, dtype, span) + + +@tvm._ffi.register_object("relax.DimType") +class DimType(Type): + """The type of indices/shape dimensions in Relax.""" + + def __init__(self, span: Span = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.DimType, span) + + +def is_base_of(base: Type, derived: Type) -> bool: + """Check the subtype relationship between base and derived. + + Parameters + ---------- + base : Type + The base type. + + derived : Type + The derived type. + + + Returns + ------- + ret : bool + If derived is a subtype of base or if both are the same type, returns true. + Otherwise returns false. + """ + return _ffi_api.IsBaseOf(base, derived) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py new file mode 100644 index 0000000000..bc8d41774b --- /dev/null +++ b/python/tvm/relax/utils.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility functions for Relax""" +from typing import List + + +def metadata_partitioner(rx_txt: str) -> List[str]: + """Extract Relax program and metadata section. + + Parameters + ---------- + rx_txt : str + The input relax text. + + Returns + ------- + output : List[str] + The result list of partitioned text, the first element + is the relax program, and the second is metadata section. + """ + partitions = [] + left_curly = 0 + meta_start = 0 + meta_end = 0 + for i, char in enumerate(rx_txt): + if i < 0: + raise ValueError("The program is invalid.") + if char == "{": + if meta_start == 0: + meta_start = i + left_curly += 1 + elif char == "}": + left_curly -= 1 + if left_curly == 0: + meta_end = i + 1 + break + + if meta_end == 0: + raise ValueError("The metadata section was not found.") + metadata = rx_txt[meta_start:meta_end] + rx_program = rx_txt[meta_end:-1] + + partitions.append(rx_program) + partitions.append(metadata) + + return partitions diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py new file mode 100644 index 0000000000..177f2e1a9a --- /dev/null +++ b/python/tvm/relax/vm.py @@ -0,0 +1,538 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, redefined-builtin, no-else-return +"""The Relax virtual machine""" +from typing import Callable, List, Optional, Union, Dict, Tuple +import numpy as np + +from tvm._ffi import base as _base +import tvm +from tvm import relax +from tvm.ir.module import IRModule +from tvm.relay import Any +from tvm.runtime import Device, Module, PackedFunc, container +from tvm.runtime.object import Object +from tvm.tir.function import PrimFunc +from . 
import _ffi_api +from ..rpc.base import RPC_SESS_MASK + + +class Executable(object): + """The executable object emitted by the VM compiler or the ExecBuilder.""" + + def __init__(self, mod: Module): + self.mod = mod + self._stats = self.mod["stats"] + self._as_text = self.mod["as_text"] + self._as_python = self.mod["as_python"] + + def stats(self) -> str: + """print the detailed statistics of the executable.""" + return self._stats() + + def as_text(self) -> str: + """print the instructions as text format.""" + return self._as_text() + + def as_python(self) -> str: + """print the instructions as python program.""" + return self._as_python() + + +class VirtualMachine(object): + """Relax VM runtime.""" + + NAIVE_ALLOCATOR = 1 + POOLED_ALLOCATOR = 2 + + def __init__( + self, + exec: Union[Executable, Module], + device: Union[Device, List[Device]], + memory_cfg: Optional[Union[str, Dict[Device, str]]] = None, + ) -> None: + """ + Construct a VirtualMachine wrapper object. + + Parameters + ---------- + exec: Union[Executable, Module] + The VM executable or Runtime Module + + device : Union[Device, List[Device]] + The device to deploy the module. + + memory_cfg : Optional[Union[str, Dict[Device, str]]] + Config the type of memory allocator. The allocator type can be ["naive", + "pooled"]. If memory_cfg is None, all devices will use pooled allocator + by default. If memory_cfg is string, all devices will use the specified + allocator type. If memory_cfg is a dict, each device uses the allocator + type specified in the dict, or pooled allocator if not specified in the + dict. + """ + self.module = ( + exec.mod["vm_load_executable"]() + if isinstance(exec, Executable) + else exec["vm_load_executable"]() + ) + self._invoke_closure = self.module["invoke_closure"] + self._save_function = self.module["save_function"] + self._set_input = self.module["set_input"] + self._invoke_stateful = self.module["invoke_stateful"] + self._get_output = self.module["get_output"] + self._get_output_arity = self.module["get_output_arity"] + self._get_function_arity = self.module["get_function_arity"] + self._get_function_param_name = self.module["get_function_param_name"] + self._setup_device(device, memory_cfg) + + def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]]) -> None: + """init devices and allocators.""" + devs = dev + if not isinstance(dev, (list, tuple)): + if not isinstance(dev, tvm.runtime.Device): + raise TypeError( + "dev is expected to be Device or \ + List[Device]" + ) + devs = [dev] + + if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]): + raise RuntimeError( + "CPU host is required to be the last element of the device list if provided." 
+            )
+
+        # CPU is required for executing shape functions
+        if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type:
+            devs.append(tvm.cpu())
+
+        default_alloc_type = VirtualMachine.POOLED_ALLOCATOR
+        if memory_cfg is None:
+            memory_cfg = {}
+        elif isinstance(memory_cfg, str):
+            assert memory_cfg in ["naive", "pooled"]
+            if memory_cfg == "naive":
+                default_alloc_type = VirtualMachine.NAIVE_ALLOCATOR
+            memory_cfg = {}
+        elif not isinstance(memory_cfg, dict):
+            raise TypeError(
+                "memory_cfg is expected to be a string or a dictionary, "
+                + "but received {}".format(type(memory_cfg))
+            )
+        init_args = []
+        for device in devs:
+            init_args.append(device.device_type % RPC_SESS_MASK)
+            init_args.append(device.device_id)
+            alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type
+            init_args.append(alloc_type)
+        self.module["vm_initialization"](*init_args)
+
+    def __getitem__(self, key: str) -> PackedFunc:
+        return self.module[key]
+
+    def invoke_closure(self, closure: Object, *args: Any) -> Object:
+        """Invoke a closure.
+
+        Parameters
+        ----------
+        closure : Object
+            The VMClosure Object.
+
+        args : list[tvm.runtime.NDArray] or list[np.ndarray]
+            The arguments to the closure.
+
+        Returns
+        -------
+        result : Object
+            The output.
+        """
+        return self._invoke_closure(closure, *args)
+
+    def save_function(
+        self,
+        func_name: str,
+        saved_name: str,
+        *args: List[Any],
+        include_return: bool = True,
+        **kwargs: Dict[str, Any],
+    ) -> None:
+        """
+        Convenience function. Takes a function from the module and saves
+        a `PackedFunc` that, when called, will invoke the function with the given arguments.
+        The `PackedFunc` can be accessed from the module using `saved_name`.
+        This is included to facilitate timing trials:
+        invoking the returned `PackedFunc` will have less overhead from dictionary lookups
+        than normally running through the VM.
+
+        If the saved name is taken, it can be overridden, though it cannot override
+        the name of a function defined in the Relax source.
+
+        This is really creating a closure, but the function has a different name
+        to avoid confusion with `invoke_closure` (they are not meant to be used together).
+
+        Parameters
+        ----------
+        func_name : str
+            The function that should be packaged up.
+
+        saved_name : str
+            The name that the resulting closure should be saved under.
+
+        include_return : bool
+            Whether the saved PackedFunc should return its output.
+            If timing over RPC, it may not be desirable to send output
+            between machines.
+
+        args : List[Any]
+            The arguments to package up with the function.
+ + kwargs : Dict[str, Any] + Any named arguments to package up with the function + """ + cargs = [] + if kwargs: + args = self._convert_func_named_args(func_name, args, **kwargs) + for arg in args: + self._convert(arg, cargs) + self._save_function(func_name, saved_name, int(include_return), *cargs) + + def _convert(self, arg: Any, cargs: List) -> None: + """helper function to convert arguments to vm function.""" + + def _gettype(arg): + if isinstance(arg, np.float16): + return "float16" + elif isinstance(arg, (_base.integer_types, bool)): + return "int32" + else: + return "float32" + + if isinstance(arg, Object): + cargs.append(arg) + elif isinstance(arg, np.ndarray): + nd_arr = tvm.nd.array(arg, device=tvm.cpu(0)) + cargs.append(nd_arr) + elif isinstance(arg, tvm.runtime.NDArray): + cargs.append(arg) + elif isinstance(arg, (tuple, list)): + field_args = [] + for field in arg: + self._convert(field, field_args) + cargs.append(container.tuple_object(field_args)) + elif isinstance(arg, (_base.numeric_types, bool)): + dtype = _gettype(arg) + value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) + cargs.append(value) + elif isinstance(arg, str): + cargs.append(arg) + else: + raise TypeError("Unsupported type: %s" % (type(arg))) + + def _convert_func_named_args(self, func_name: str, args: Any, **kwargs: Any) -> Any: + """ + Takes named function parameters and returns a list of those needed, + in the order they should appear + """ + # kwargs can be a super set of the required function parameters. + # We only find the ones that are needed. + func_arity = self._get_function_arity(func_name) + func_params = [self._get_function_param_name(func_name, i) for i in range(func_arity)] + new_args = [None] * len(func_params) + cnt = 0 + for k in kwargs: + if k in func_params: + idx = func_params.index(k) + new_args[idx] = kwargs[k] + cnt += 1 + else: + print(f'Warning: Keyword argument "{k}" is unused in {func_name}') + assert len(args) + cnt == len(func_params) + idx = 0 + for i, arg in enumerate(new_args): + if arg is None: + new_args[i] = args[idx] + idx += 1 + return new_args + + def set_input(self, func_name: str, *args: Any, **kwargs: Any) -> None: + """Set the inputs to a function. + This interface works when using VM over RPC by internally converting NDArray in + the arguments to DLTensor, which is supported in RPC where remote could only + have a minimal C runtime. + + Note: If `set_input` is used, the function *must* be called using `invoke_stateful` + and the results must be obtained using `get_outputs`. + + Parameters + ---------- + func_name : str + The name of the function. + args: List[tvm.runtime.NDArray] or List[np.ndarray] + The arguments to the function. + kwargs: dict of str to tvm.runtime.NDArray or np.ndarray + Named arguments to the function. + """ + cargs = [] + + if kwargs: + args = self._convert_func_named_args(func_name, args, **kwargs) + + for arg in args: + self._convert(arg, cargs) + + self._set_input(func_name, *cargs) + + def invoke_stateful(self, func_name: str) -> None: + """ + Call the named function from the VM module using the arguments set using `set_input`. + It is an error to call `invoke_stateful` without using `set_input` first + (even if it's to set 0 inputs); conversely, if `set_input` has been called, + it is an error to call the function without using `invoke_stateful`. + + The results of the call can be obtained by calling `get_outputs`. + + Parameters + ---------- + func_name: str + The name of the function to call. 
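+
+        Example
+        -------
+        A minimal sketch; the function name "main" and the input array `x`
+        are illustrative assumptions:
+
+        .. code-block:: python
+
+            vm.set_input("main", tvm.nd.array(x))
+            vm.invoke_stateful("main")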
+ """ + self._invoke_stateful(func_name) + + def get_outputs(self, func_name: str) -> Union[tvm.Object, Tuple[Any]]: + """ + Get the value output by the function by the given name + after a call of `invoke_stateful`. + + It is an error to call this function without first calling `invoke_stateful`. + + Parameters + ---------- + func_name: str + The name of the function whose output should be fetched. + + Returns + ------- + ret: Union[tvm.Object, Tuple[Any]] + The result of the earlier call to the function via `invoke_stateful`. + If the result is a tuple, it returns a list of the fields. + The fields are potentially also tuples, so these can be arbitrily nested. + """ + # to deal with potentially nested tuples, we need to query for arity recursively + def get_output_rec(func_name, *idx): + arity = self._get_output_arity(func_name, *idx) + if arity == -1: + return self._get_output(func_name, *idx) + # otherwise we need to specify more indices + idx_list = list(idx) + return tuple(get_output_rec(func_name, *(idx_list + [i])) for i in range(arity)) + + return get_output_rec(func_name) + + def time_evaluator( + self, + func_name, + dev, + number=10, + repeat=1, + min_repeat_ms=0, + cooldown_interval_ms=0, + repeats_to_cooldown=1, + f_preproc="", + ) -> Callable[..., tvm.runtime.module.BenchmarkResult]: + """ + Returns an evaluator that times a function in the module. + This follows the same convention as time_evaluator in tvm.runtime.module. + This can be used in combination with save_function() so that the + timings avoid extra dictionary lookups. + + Parameters + ---------- + func_name: str + The name of the function in the module. + + dev: Device + The device we should run this function on. + + number: int + The number of times to run this function for taking average. + We call these runs as one `repeat` of measurement. + + repeat: int, optional + The number of times to repeat the measurement. + In total, the function will be invoked (1 + number x repeat) times, + where the first one is warm up and will be discarded. + The returned result contains `repeat` costs, + each of which is an average of `number` costs. + + min_repeat_ms: int, optional + The minimum duration of one `repeat` in milliseconds. + By default, one `repeat` contains `number` runs. If this parameter is set, + the parameters `number` will be dynamically adjusted to meet the + minimum duration requirement of one `repeat`. + i.e., When the run time of one `repeat` falls below this time, the `number` parameter + will be automatically increased. + + cooldown_interval_ms: int, optional + The cooldown interval in milliseconds between the number of repeats defined by + `repeats_to_cooldown`. + + repeats_to_cooldown: int, optional + The number of repeats before the cooldown is activated. + + f_preproc: str, optional + The preprocess function name we want to execute before executing the time evaluator. + + Note + ---- + The function will be invoked (1 + number x repeat) times, + with the first call discarded in case there is lazy initialization. + + Example + ------- + Normal use with a VM function (may not work over RPC if the function returns a tuple): + + .. code-block:: python + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(mod, tvm.cpu()) + timing_res = vm.time_evaluator("func_name", tvm.cpu())(arg0, arg1, ..., argn) + + Use with the stateful API: + + .. 
+        """
+        # To deal with potentially nested tuples, we need to query for arity recursively.
+        def get_output_rec(func_name, *idx):
+            arity = self._get_output_arity(func_name, *idx)
+            if arity == -1:
+                return self._get_output(func_name, *idx)
+            # Otherwise we need to specify more indices.
+            idx_list = list(idx)
+            return tuple(get_output_rec(func_name, *(idx_list + [i])) for i in range(arity))
+
+        return get_output_rec(func_name)
+
+    def time_evaluator(
+        self,
+        func_name,
+        dev,
+        number=10,
+        repeat=1,
+        min_repeat_ms=0,
+        cooldown_interval_ms=0,
+        repeats_to_cooldown=1,
+        f_preproc="",
+    ) -> Callable[..., tvm.runtime.module.BenchmarkResult]:
+        """
+        Returns an evaluator that times a function in the module.
+        This follows the same convention as time_evaluator in tvm.runtime.module.
+        This can be used in combination with save_function() so that the
+        timings avoid extra dictionary lookups.
+
+        Parameters
+        ----------
+        func_name: str
+            The name of the function in the module.
+
+        dev: Device
+            The device we should run this function on.
+
+        number: int
+            The number of times to run this function for taking the average.
+            We call these runs one `repeat` of measurement.
+
+        repeat: int, optional
+            The number of times to repeat the measurement.
+            In total, the function will be invoked (1 + number x repeat) times,
+            where the first one is a warm-up and will be discarded.
+            The returned result contains `repeat` costs,
+            each of which is an average of `number` costs.
+
+        min_repeat_ms: int, optional
+            The minimum duration of one `repeat` in milliseconds.
+            By default, one `repeat` contains `number` runs. If this parameter is set,
+            the parameter `number` will be dynamically adjusted to meet the
+            minimum duration requirement of one `repeat`,
+            i.e., when the run time of one `repeat` falls below this time,
+            the `number` parameter will be automatically increased.
+
+        cooldown_interval_ms: int, optional
+            The cooldown interval in milliseconds between the number of repeats defined by
+            `repeats_to_cooldown`.
+
+        repeats_to_cooldown: int, optional
+            The number of repeats before the cooldown is activated.
+
+        f_preproc: str, optional
+            The preprocess function name we want to execute before executing the time evaluator.
+
+        Note
+        ----
+        The function will be invoked (1 + number x repeat) times,
+        with the first call discarded in case there is lazy initialization.
+
+        Example
+        -------
+        Normal use with a VM function (may not work over RPC if the function returns a tuple):
+
+        .. code-block:: python
+
+            target = tvm.target.Target("llvm", host="llvm")
+            ex = relax.vm.build(TestTimeEvaluator, target)
+            vm = relax.VirtualMachine(ex, tvm.cpu())
+            timing_res = vm.time_evaluator("func_name", tvm.cpu())(arg0, arg1, ..., argn)
+
+        Use with the stateful API:
+
+        .. code-block:: python
+
+            target = tvm.target.Target("llvm", host="llvm")
+            ex = relax.vm.build(TestTimeEvaluator, target)
+            vm = relax.VirtualMachine(ex, tvm.cpu())
+            vm.set_input("func_name", arg0, arg1, ..., argn)
+            timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("func_name")
+
+        With saved closures via `save_function` (this results in
+        fewer dictionary lookups in the timed portion):
+
+        .. code-block:: python
+
+            target = tvm.target.Target("llvm", host="llvm")
+            ex = relax.vm.build(TestTimeEvaluator, target)
+            vm = relax.VirtualMachine(ex, tvm.cpu())
+            vm.save_function("func_name", "func_name_saved", arg0, arg1, ..., argn)
+            timing_res = vm.time_evaluator("func_name_saved", tvm.cpu())()
+
+        Returns
+        -------
+        ftimer : function
+            The function that takes the same arguments as func and returns a BenchmarkResult.
+            The BenchmarkResult reports `repeat` time costs in seconds.
+        """
+        return self.module.time_evaluator(
+            func_name,
+            dev,
+            number=number,
+            repeat=repeat,
+            min_repeat_ms=min_repeat_ms,
+            cooldown_interval_ms=cooldown_interval_ms,
+            repeats_to_cooldown=repeats_to_cooldown,
+            f_preproc=f_preproc,
+        )
+
+
+def build(
+    mod: tvm.IRModule,
+    target: Union[str, tvm.target.Target],
+    params: Optional[Dict[str, list]] = None,
+) -> Executable:
+    """
+    Build an IRModule to a VM executable.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The input IRModule to be built.
+
+    target : Union[str, tvm.target.Target]
+        A build target which can have an optional host-side compilation target.
+
+        When TVM compiles a device-specific program such as CUDA,
+        we also need host (CPU) side code to interact with the driver
+        to set up the dimensions and parameters correctly.
+        host is used to specify the host-side codegen target.
+        By default, llvm is used if it is enabled,
+        otherwise a stackvm interpreter is used.
+
+    params: Optional[Dict[str, list]]
+        Parameters for the input IRModule that will be bound.
+
+    Returns
+    -------
+    ex: tvm.relax.vm.Executable
+        An executable that can be loaded by the virtual machine.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        class InputModule:
+            @R.function
+            def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
+                z = R.add(x, y)
+                return z
+
+        mod = InputModule
+        target = tvm.target.Target("llvm", host="llvm")
+        ex = relax.vm.build(mod, target)
+    """
+    if isinstance(target, str):
+        target = tvm.target.Target(target)
+
+    passes = [relax.transform.ToNonDataflow()]
+    passes.append(relax.transform.CallTIRRewrite())
+    passes.append(relax.transform.VMMemoryLower())
+    passes.append(relax.transform.VMShapeLower())
+    seq = tvm.transform.Sequential(passes)
+    new_mod = seq(mod)
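+    # A rough summary of the sequence above: ToNonDataflow removes dataflow
+    # blocks, CallTIRRewrite makes the implicit allocations of call_tir
+    # explicit, and VMMemoryLower / VMShapeLower map memory and shape
+    # computations onto VM builtin operations before code generation.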
+
+    # Split primfunc and relax function
+    rx_mod, tir_mod = _split_tir_relax(new_mod)
+    lib = tvm.build(tir_mod, target=target)
+
+    # Extract external runtime modules if they exist.
+    ext_libs = []
+    if mod.attrs and "external_mods" in mod.attrs:
+        ext_libs = mod.attrs["external_mods"]
+
+    if params is None:
+        params = {}
+
+    return Executable(_ffi_api.VMCodeGen(rx_mod, lib, ext_libs, target, params))
+
+
+def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]:
+    rx_mod = IRModule({})
+    tir_mod = IRModule({})
+    for gv in mod.get_global_vars():
+        if isinstance(mod[gv], PrimFunc):
+            tir_mod[gv] = mod[gv]
+        elif isinstance(mod[gv], relax.Function):
+            rx_mod[gv] = mod[gv]
+        else:
+            raise TypeError(
+                "IRModule is expected to contain PrimFunc or Function, but got {}".format(
+                    type(mod[gv])
+                )
+            )
+    return rx_mod, tir_mod
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 97842738e5..966d700300 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -70,6 +70,7 @@
 # Span
 Span = base.Span
 SourceName = base.SourceName
+Id = base.Id
 
 # Type
 Type = ty.Type
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
index 323a8f6e5a..c8f5fe6ad1 100644
--- a/python/tvm/relay/base.py
+++ b/python/tvm/relay/base.py
@@ -21,6 +21,7 @@
 from tvm.runtime import Object
 from tvm.ir import SourceName, Span, Node as RelayNode
+from . import _ffi_api
 
 __STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
 
@@ -37,5 +38,5 @@ class Id(Object):
     Guaranteed to be stable across all passes.
     """
 
-    def __init__(self):
-        raise RuntimeError("Cannot directly construct Id")
+    def __init__(self, string):
+        self.__init_handle_by_constructor__(_ffi_api.Id, string)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index fefc285723..0f763ba4fc 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -23,7 +23,7 @@
 import tvm._ffi
 from tvm._ffi import base as _base
 from tvm.runtime import NDArray, ndarray as _nd
-from tvm.ir import RelayExpr, GlobalVar, Node
+from tvm.ir import RelayExpr, GlobalVar, Node, Span
 from .base import RelayNode
 from . import _ffi_api
 
@@ -301,8 +301,8 @@ class If(ExprWithOp):
         The expression evaluated when condition is false.
     """
 
-    def __init__(self, cond, true_branch, false_branch):
-        self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch)
+    def __init__(self, cond, true_branch, false_branch, span: Span = None):
+        self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, false_branch, span)
 
 
 @tvm._ffi.register_object("relay.TupleGetItem")
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index e85b992341..071be8e428 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -493,7 +493,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
                 object_format = "cu"
             has_c_module = True
         else:
-            assert module.type_key == "llvm" or module.type_key == "static_library"
+            assert module.is_dso_exportable
            object_format = "o"
         path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}")
         module.save(path_obj)
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..bedc4d3417 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,25 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
+from . import _parser, parser_v1
-from .
import tir +############# +from ._parser import ir as ir_v2 +from ._parser import ir_module as ir_module_v2 +from ._parser import parse as from_source_v2 +from ._parser import relax as relax_v2 +from ._parser import tir as tir_v2 -from .parser import ir_module, from_source +############# +from .parser_v1 import from_source as from_source_v1 +from .parser_v1 import ir_module as ir_module_v1 +from .parser_v1 import relax as relax_v1 +from .parser_v1 import tir as tir_v1 + +# pylint: disable=invalid-name + +# ir = ir_v1 +ir_module = ir_module_v1 +tir = tir_v1 +relax = relax_v1 +from_source = from_source_v1 diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/__init__.py index d885b40525..678297799e 100644 --- a/python/tvm/script/_parser/__init__.py +++ b/python/tvm/script/_parser/__init__.py @@ -15,4 +15,8 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import _core +from . import _core, ir, tir, relax +from ._core import parse +from .ir import ir_module +from .tir import prim_func +from .relax import function diff --git a/python/tvm/script/_parser/_core.py b/python/tvm/script/_parser/_core.py index a2dcc5b531..a1d307d648 100644 --- a/python/tvm/script/_parser/_core.py +++ b/python/tvm/script/_parser/_core.py @@ -13,7 +13,10 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the Licens. +# under the License. """The core parser infra""" # pylint: disable=unused-import -from .core import doc, utils +from .core import dispatch, doc, utils +from .core.dispatch import OpMethod, register_op +from .core.entry import parse +from .core.parser import Parser diff --git a/python/tvm/script/_parser/core/__init__.py b/python/tvm/script/_parser/core/__init__.py index ae1521006d..94d8dab032 100644 --- a/python/tvm/script/_parser/core/__init__.py +++ b/python/tvm/script/_parser/core/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """The core parser infra""" -from . import diagnostics, doc, doc_core, utils +from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils diff --git a/python/tvm/script/_parser/core/diagnostics.py b/python/tvm/script/_parser/core/diagnostics.py index b077d22142..f9baa1574c 100644 --- a/python/tvm/script/_parser/core/diagnostics.py +++ b/python/tvm/script/_parser/core/diagnostics.py @@ -204,7 +204,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/_parser/core/dispatch.py b/python/tvm/script/_parser/core/dispatch.py new file mode 100644 index 0000000000..f10b90961a --- /dev/null +++ b/python/tvm/script/_parser/core/dispatch.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type + +from .doc import AST + +if TYPE_CHECKING: + from .parser import Parser + + +ParseMethod = Callable[["Parser", AST], None] +ParseVTable: Dict[Tuple[str, str], ParseMethod] = {} + +OpMethod = Callable[..., Any] +OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {} + + +def register(token: str, type_name: str): + """Register a method for a dispatch token and type name""" + + def f(method: ParseMethod): + ParseVTable[(token, type_name)] = method + + return f + + +def get( + token: str, + type_name: str, + default: Optional[ParseMethod] = None, +) -> Optional[ParseMethod]: + return ParseVTable.get((token, type_name), default) + + +def register_op(ty: Type, op: AST, operand_index: int): # pylint: disable=invalid-name + def f(method: OpMethod): + OpVTable[(ty, op, operand_index)] = method + + return f + + +def get_op( # pylint: disable=invalid-name + ty: Type, + op: Type, + operand_index: int, + default: Optional[OpMethod] = None, +) -> Optional[OpMethod]: + return OpVTable.get((ty, op, operand_index), default) diff --git a/python/tvm/script/_parser/core/entry.py b/python/tvm/script/_parser/core/entry.py new file mode 100644 index 0000000000..afd3cb5027 --- /dev/null +++ b/python/tvm/script/_parser/core/entry.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""The entry point of TVM parser.""" +from typing import Any, Union + +from ...ir_builder import IRBuilder +from . 
import doc +from .diagnostics import Source +from .parser import Parser + + +def parse(program: Union[doc.AST, Any, str], extra_vars=None): + if extra_vars is None: + from tvm.script._parser import ir # pylint: disable=import-outside-toplevel + from tvm.script._parser import relax # pylint: disable=import-outside-toplevel + from tvm.script._parser import tir # pylint: disable=import-outside-toplevel + + extra_vars = { + "I": ir, + "ir": ir, + "T": tir, + "tir": tir, + "relax": relax, + "R": relax, + } + + source = Source(program) + parser = Parser(source) + with IRBuilder() as builder: + parser.parse(extra_vars=extra_vars) + return builder.get() diff --git a/python/tvm/script/_parser/core/evaluator.py b/python/tvm/script/_parser/core/evaluator.py new file mode 100644 index 0000000000..0c2ccee48a --- /dev/null +++ b/python/tvm/script/_parser/core/evaluator.py @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""AST Evaluation""" +import ast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union + +from . 
import dispatch, doc + +if TYPE_CHECKING: + from .parser import Parser + +DEFAULT_OP: Dict[Type, Callable[..., Any]] = { + doc.Add: lambda a, b: a + b, + doc.Sub: lambda a, b: a - b, + doc.Mult: lambda a, b: a * b, + doc.Div: lambda a, b: a / b, + doc.FloorDiv: lambda a, b: a // b, + doc.Mod: lambda a, b: a % b, + doc.LShift: lambda a, b: a << b, + doc.RShift: lambda a, b: a >> b, + doc.BitOr: lambda a, b: a | b, + doc.BitXor: lambda a, b: a ^ b, + doc.BitAnd: lambda a, b: a & b, + doc.MatMult: lambda a, b: a @ b, + # fmt: off + doc.Pow: lambda a, b: a**b, + # fmt: on + doc.Eq: lambda a, b: a == b, + doc.NotEq: lambda a, b: a != b, + doc.Lt: lambda a, b: a < b, + doc.LtE: lambda a, b: a <= b, + doc.Gt: lambda a, b: a > b, + doc.GtE: lambda a, b: a >= b, + doc.Is: lambda a, b: a is b, + doc.IsNot: lambda a, b: a is not b, + doc.In: lambda a, b: a in b, + doc.NotIn: lambda a, b: a not in b, + doc.And: lambda a, b: a and b, + doc.Or: lambda a, b: a or b, + doc.Invert: lambda a: ~a, + doc.Not: lambda a: not a, + doc.UAdd: lambda a: +a, + doc.USub: lambda a: -a, +} + + +class ExprEvaluator: + + parser: "Parser" + value_table: Dict[str, Any] + new_value_count: int + + def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None: + super().__init__() + self.parser = parser + self.value_table = value_table + self.new_value_count = 0 + + @staticmethod + def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: + self = ExprEvaluator(parser, value_table) + result = self._visit(node) # pylint: disable=protected-access + if isinstance(result, doc.Name): + if result.id not in self.value_table: + self.parser.report_error(result, f"Undefined variable: {result.id}") + return self.value_table[result.id] + if isinstance(result, doc.Constant): + return result.value + raise TypeError(f"Unexpected result type: {type(result)}") + + def _add_intermediate_result(self, value: Any) -> doc.Name: + name = f"__tvm_tmp_value_{self.new_value_count}" + self.new_value_count += 1 + self.value_table[name] = value + lineno = 0 + col_offset = 0 + return doc.Name( + id=name, + ctx=doc.Load( + lineno=lineno, + col_offset=col_offset, + end_lineno=None, + end_col_offset=None, + ), + lineno=lineno, + col_offset=col_offset, + end_lineno=None, + end_col_offset=None, + ) + + def _visit(self, node: doc.AST) -> Any: + if isinstance(node, list): + return [self._visit(n) for n in node] + if isinstance(node, tuple): + return tuple(self._visit(n) for n in node) + assert isinstance(node, doc.AST) + if isinstance(node, doc.Name): + if node.id not in self.value_table: + self.parser.report_error(node, f"Undefined variable: {node.id}") + return node + if isinstance( + node, + ( + doc.Constant, + doc.expr_context, + doc.operator, + doc.boolop, + doc.unaryop, + doc.cmpop, + ), + ): + return node + if not isinstance(node, (doc.expr, doc.slice)): + return node + if isinstance(node, doc.Lambda): + return self._eval_lambda(node) + fields = {} + for field in node.__class__._FIELDS: # pylint: disable=protected-access + attr = getattr(node, field) + if isinstance(attr, (doc.AST, tuple, list)): + fields[field] = self._visit(attr) + else: + fields[field] = attr + try: + if isinstance(node, doc.BoolOp): + value = self._eval_bool_op(fields) + elif isinstance(node, doc.Compare): + value = self._eval_compare(fields) + elif isinstance(node, doc.UnaryOp): + value = self._eval_unary_op(fields) + elif isinstance(node, doc.BinOp): + value = self._eval_bin_op(fields) + elif isinstance(node, doc.Slice): + value = 
self._eval_slice(fields) + else: + value = self._eval_expr(node.__class__(**fields)) + except Exception as e: # pylint: disable=broad-except,invalid-name + self.parser.report_error(node, str(e)) + return self._add_intermediate_result(value) + + def _eval_lambda(self, node: doc.Lambda) -> Any: + try: + value = self._eval_expr(node) + except Exception as e: # pylint: disable=broad-except,invalid-name + self.parser.report_error(node, str(e)) + return self._add_intermediate_result(value) + + def _eval_bool_op(self, fields: Dict[str, Any]) -> Any: + op = fields["op"] + if not isinstance(op, (doc.And, doc.Or)): + raise TypeError(f"Unexpected operator: {op}") + value = self._eval_expr(fields["values"][0]) + for rhs in fields["values"][1:]: + value = _eval_op(op, values=[value, self._eval_expr(rhs)]) + return value + + def _eval_compare(self, fields: Dict[str, Any]) -> Any: + value = self._eval_expr(fields["left"]) + for op, rhs in zip(fields["ops"], fields["comparators"]): + value = _eval_op(op, values=[value, self._eval_expr(rhs)]) + return value + + def _eval_unary_op(self, fields: Dict[str, Any]) -> Any: + value = self._eval_expr(fields["operand"]) + value = _eval_op(fields["op"], values=[value]) + return value + + def _eval_bin_op(self, fields: Dict[str, Any]) -> Any: + return _eval_op( + fields["op"], + values=[ + self._eval_expr(fields["left"]), + self._eval_expr(fields["right"]), + ], + ) + + def _eval_slice(self, fields: Dict[str, Any]) -> Any: + lower, upper, step = fields["lower"], fields["upper"], fields["step"] + + lower = self._eval_expr(lower) if lower is not None else None + upper = self._eval_expr(upper) if upper is not None else None + step = self._eval_expr(step) if step is not None else None + + return slice(lower, upper, step) + + def _eval_expr(self, v: Any) -> Any: + return _eval_expr(v, self.value_table) + + +def eval_expr( + parser: "Parser", + node: Union[doc.expr, doc.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + value_table = {} + if dict_globals is not None: + value_table.update(dict_globals) + return ExprEvaluator.eval(parser, value_table, node) + + +def eval_assign( + parser: "Parser", + target: doc.expr, + source: Any, +) -> Dict[str, Any]: + try: + return _eval_assign(target, source) + except Exception as e: # pylint: disable=broad-except,invalid-name + parser.report_error(target, f"Failed to evaluate assignment: {str(e)}") + raise + + +def _eval_expr( + node: Union[doc.expr, doc.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + node = doc.from_doc(node) + if isinstance(node, ast.expr): + node = ast.Expression(body=node) + assert isinstance(node, ast.Expression), "Expects an ast.Expression, but gets: " + str(node) + if dict_globals is None: + dict_globals = {} + node = ast.fix_missing_locations(node) + exe = compile(node, filename="", mode="eval") + return eval(exe, dict_globals) # pylint: disable=eval-used + + +def _eval_op( + op: doc.AST, + values: List[Any], +): + op_type = type(op) # pylint: disable=protected-access + for i, v in enumerate(values): + v_type = getattr(type(v), "_dispatch_type", None) + if v_type is None: + continue + f = dispatch.get_op(ty=v_type, op=op_type, operand_index=i, default=None) + if f is not None: + return f(*values) + return DEFAULT_OP[op_type](*values) + + +def _eval_assign( + target: doc.expr, + source: Any, +) -> Dict[str, Any]: + target = doc.from_doc(target) + assert isinstance(target, ast.expr) + RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name + rhs_var_name = 
RHS_VAR_NAME + dict_locals = {rhs_var_name: source} + mod = ast.fix_missing_locations( + ast.Module( + body=[ + ast.Assign( + targets=[target], + value=ast.Name( + id=rhs_var_name, + ctx=ast.Load(), + ), + ) + ], + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + exec(exe, {}, dict_locals) # pylint: disable=exec-used + del dict_locals[rhs_var_name] + return dict_locals diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py new file mode 100644 index 0000000000..7846bd8c0f --- /dev/null +++ b/python/tvm/script/_parser/core/parser.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""The core parser""" +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Set, Union + +from tvm._ffi.base import TVMError +from tvm.error import DiagnosticError + +from . import dispatch, doc +from .diagnostics import Diagnostics, Source +from .evaluator import eval_assign, eval_expr + +DEFAULT_VISIT = { + "Interactive", + "Module", + "Expression", + "Pass", +} + + +def _deferred(f: Callable[[], None]): + @contextmanager + def context(): + try: + yield + finally: + f() + + return context() + + +class VarTableFrame: + vars: Set[str] + + def __init__(self): + self.vars = set() + + def add(self, var: str): + if var in self.vars: + raise ValueError(f"Variable {var} already defined in current scope") + self.vars.add(var) + + def pop_all(self, fn_pop: Callable[[str], None]): + for var in self.vars: + fn_pop(var) + self.vars.clear() + + +class VarTable: + + frames: List[VarTableFrame] + name2value: Dict[str, List[Any]] + + def __init__(self): + self.frames = [] + self.name2value = defaultdict(list) + + def with_frame(self): + def pop_frame(): + frame = self.frames.pop() + frame.pop_all(lambda name: self.name2value[name].pop()) + + self.frames.append(VarTableFrame()) + return _deferred(pop_frame) + + def add(self, var: str, value: Any, allow_shadowing: bool = False): + # Skip if the key and value are equal to those in the var_table + if self.name2value[var] and self.name2value[var][-1] == value: + return + if allow_shadowing and var in self.frames[-1].vars: + # Shadowing + self.name2value[var][-1] = value + else: + self.frames[-1].add(var) + self.name2value[var].append(value) + + def get(self) -> Dict[str, Any]: + return {key: values[-1] for key, values in self.name2value.items() if values} + + def exist(self, value: Any): + for v in self.name2value.values(): + if v is value: + return True + return False + + +def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: + def _wrapper(self: "Parser", node: doc.AST) -> None: + try: + return func(self, node) + except 
DiagnosticError: + raise + except Exception as e: # pylint: disable=broad-except,invalid-name + self.report_error(node, e) + raise + + return _wrapper + + +def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: + for token in [self.dispatch_tokens[-1], "default"]: + func = dispatch.get(token=token, type_name=type_name, default=None) + if func is not None: + return _dispatch_wrapper(func) + return _dispatch_wrapper(lambda self, node: self.generic_visit(node)) + + +def _dispatch_optional(self: "Parser", type_name: str) -> Optional[dispatch.ParseMethod]: + for token in [self.dispatch_tokens[-1], "default"]: + func = dispatch.get(token=token, type_name=type_name, default=None) + if func is not None: + return _dispatch_wrapper(func) + return None + + +class Parser(doc.NodeVisitor): + """The TVMScript parser""" + + diag: Diagnostics + dispatch_tokens: List[str] + var_table: VarTable + + def __init__(self, source: Source) -> None: + self.diag = Diagnostics(source) + self.dispatch_tokens = ["default"] + self.var_table = VarTable() + + def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: + if extra_vars is None: + extra_vars = {} + with self.var_table.with_frame(): + for k, v in extra_vars.items(): + self.var_table.add(k, v) + node = self.diag.source.as_ast() + self.visit(node) + + def with_dispatch_token(self, token: str): + def pop_token(): + self.dispatch_tokens.pop() + + self.dispatch_tokens.append(token) + return _deferred(pop_token) + + def eval_expr( + self, + node: Union[doc.Expression, doc.expr], + extra_vars: Optional[Dict[str, Any]] = None, + ) -> Any: + var_values = self.var_table.get() + if extra_vars is not None: + for k, v in extra_vars.items(): + var_values[k] = v + return eval_expr(self, node, var_values) + + def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: + if isinstance(target, (doc.Tuple, doc.List)): + vars: Set[str] = set() # pylint: disable=redefined-builtin + for i in target.elts: + res = self._duplicate_lhs_check(i) + if isinstance(res, bool) and res: + return True + assert isinstance(res, set) + if vars & res: + return True + vars = vars.union(res) + return vars + elif isinstance(target, doc.Name): + return {target.id} + else: + self.report_error(target, "Invalid type in assign statement") + raise NotImplementedError + + def eval_assign( + self, + target: doc.expr, + source: Any, + bind_value: Callable[["Parser", doc.expr, str, Any], Any], + allow_shadowing: bool = False, + ) -> Dict[str, Any]: + if self._duplicate_lhs_check(target) is True: + self.report_error(target, "Duplicate vars assigned.") + var_values = eval_assign(self, target, source) + for k, v in var_values.items(): + var = bind_value(self, target, k, v) + self.var_table.add(k, var, allow_shadowing) + return var_values + + def report_error( + self, node: doc.AST, err: Union[Exception, str] + ) -> None: # pylint: disable=no-self-use + # Only take the last line of the error message + if isinstance(err, (TVMError, ValueError, TypeError)): + msg = list(filter(None, str(err).split("\n")))[-1] + else: + msg = str(err) + self.diag.error(node, msg) + + def visit(self, node: doc.AST) -> None: + if isinstance(node, (list, tuple)): + for item in node: + self.visit(item) + return + if not isinstance(node, doc.AST): + return + name = node.__class__.__name__.split(".")[-1] + if name in DEFAULT_VISIT: + func = self.generic_visit + else: + func = getattr(self, "visit_" + name, None) + if func is None: + raise NotImplementedError(f"Visitor of AST node is not implemented: 
{name}") + try: + func(node) + except DiagnosticError: + raise + except Exception as e: # pylint: disable=broad-except,invalid-name + self.report_error(node, str(e)) + raise + + def visit_body(self, node: List[doc.stmt]) -> Any: + for stmt in node: + self.visit(stmt) + + def visit_tvm_annotation(self, node: doc.expr) -> Any: + return _dispatch(self, "tvm_annotation")(self, node) + + def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + token = decorator.dispatch_token + func = dispatch.get(token=token, type_name="FunctionDef", default=None) + if func is None: + self.report_error(node, "The parser does not understand the decorator") + pre_func = _dispatch_optional(self, "pre_token_switch") + post_func = _dispatch_optional(self, "post_token_switch") + if pre_func: + pre_func(self, node) + _dispatch_wrapper(func)(self, node) + if post_func: + post_func(self, node) + + def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name + func = dispatch.get(token="ir", type_name="ClassDef", default=None) + if func is None: + self.report_error(node, "The parser does not understand the decorator") + _dispatch_wrapper(func)(self, node) + + def visit_arguments(self, node: doc.arguments) -> Any: + return _dispatch(self, "arguments")(self, node) + + def visit_For(self, node: doc.For) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "For")(self, node) + + def visit_While(self, node: doc.While) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "While")(self, node) + + def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "With")(self, node) + + def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Assign")(self, node) + + def visit_Expr(self, node: doc.Expr) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Expr")(self, node) + + def visit_If(self, node: doc.If) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "If")(self, node) + + def visit_AnnAssign(self, node: doc.AnnAssign) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "AnnAssign")(self, node) + + def visit_AugAssign(self, node: doc.AugAssign) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "AugAssign")(self, node) + + def visit_Assert(self, node: doc.Assert) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Assert")(self, node) + + def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Return")(self, node) diff --git a/python/tvm/script/_parser/ir/__init__.py b/python/tvm/script/_parser/ir/__init__.py new file mode 100644 index 0000000000..8cf9b50665 --- /dev/null +++ b/python/tvm/script/_parser/ir/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from . import parser as _parser +from .entry import ir_module, is_defined_in_class + +__all__ = ["ir_module", "is_defined_in_class"] diff --git a/python/tvm/script/_parser/ir/entry.py b/python/tvm/script/_parser/ir/entry.py new file mode 100644 index 0000000000..e0a0213cd1 --- /dev/null +++ b/python/tvm/script/_parser/ir/entry.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import inspect +from typing import Type + +from tvm.ir import IRModule + +from .._core import parse, utils + + +def is_defined_in_class(frames): + if len(frames) > 2: + maybe_class_frame = frames[2] + statement_list = maybe_class_frame[4] + if statement_list is None: + return False + first_statement = statement_list[0] + line = first_statement.strip() + if line.startswith("class "): + return True + if line.startswith("@") and "ir_module" in line: + return True + return False + + +def ir_module(f: Type) -> IRModule: + if not inspect.isclass(f): + raise TypeError(f"Expect a class, but got: {f}") + + return parse(f, utils.inspect_class_capture(f)) + + +setattr(ir_module, "dispatch_token", "ir") diff --git a/python/tvm/script/_parser/ir/parser.py b/python/tvm/script/_parser/ir/parser.py new file mode 100644 index 0000000000..eacbe9641c --- /dev/null +++ b/python/tvm/script/_parser/ir/parser.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +from ...ir_builder import ir as I +from .._core import Parser, dispatch, doc + + +@dispatch.register(token="ir", type_name="ClassDef") +def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: + with self.var_table.with_frame(): + with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + global_var = I.decl_function(stmt.name) + self.var_table.add(stmt.name, global_var) + with self.with_dispatch_token("ir"): + self.visit_body(node.body) + + +@dispatch.register(token="ir", type_name="Assign") +def _visit_assign(_self: Parser, _node: doc.Assign) -> None: + pass + + +@dispatch.register(token="ir", type_name="Expr") +def _visit_expr(_self: Parser, _node: doc.Expr) -> None: + pass diff --git a/python/tvm/script/_parser/relax/__init__.py b/python/tvm/script/_parser/relax/__init__.py new file mode 100644 index 0000000000..ed85bd8af6 --- /dev/null +++ b/python/tvm/script/_parser/relax/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from ...ir_builder.relax import * # pylint: disable=redefined-builtin +from ...ir_builder.relax import ir as _relax +from . import parser as _parser +from .entry import Callable, Tensor, function, match_shape + +__all__ = _relax.__all__ + ["Callable", "Tensor", "function", "match_shape"] diff --git a/python/tvm/script/_parser/relax/entry.py b/python/tvm/script/_parser/relax/entry.py new file mode 100644 index 0000000000..453afbaf17 --- /dev/null +++ b/python/tvm/script/_parser/relax/entry.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring, invalid-name +import inspect +from typing import Callable as _Callable +from typing import List, Optional +from typing import TypeVar as _TypeVar +from typing import Union + +from tvm.ir import FuncType, TypeConstraint, TypeVar +from tvm.relax import Expr, Function, Type, Var +from tvm.tir import PrimExpr + +from ...ir_builder.relax import TensorType, tensor +from .._core import parse, utils +from ..ir import is_defined_in_class + +FType = _TypeVar("FType", bound=_Callable) + + +def function(f: FType) -> Union[Function, FType]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if is_defined_in_class(inspect.stack()): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(function, "dispatch_token", "relax") + + +class TensorProxy: + def __call__( + self, + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: str = None, + ndim: int = -1, + ) -> TensorType: + return tensor(shape, dtype, ndim) + + def __getitem__(self, keys) -> Var: + return self(*keys) # pylint: disable=no-member # type: ignore + + +Tensor = TensorProxy() # pylint: disable=invalid-name + + +class CallableProxy: + """Function type. + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + We can informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + + Parameters + ---------- + arg_types : List[Type] + The argument types + + ret_type : Type + The return type. + + type_params : Optional[List[TypeVar]] + The type parameters + + type_constraints : Optional[List[TypeConstraint]] + The type constraints. + """ + + def __call__( + self, + arg_types: List[Type], + ret_type: Type, + type_params: Optional[List[TypeVar]] = None, + type_constraints: Optional[List[TypeConstraint]] = None, + ) -> FuncType: + def _convert_type(ty: Union[Type, TensorType]) -> Type: + if isinstance(ty, TensorType): + return ty.type + elif isinstance(ty, Type): + return ty + else: + raise TypeError(f"Expect a Type or TensorType, but got: {ty}") + + arg_types = [_convert_type(ty) for ty in arg_types] + ret_type = _convert_type(ret_type) + return FuncType(arg_types, ret_type, type_params, type_constraints) + + def __getitem__(self, keys) -> Var: + return self(*keys) # pylint: disable=no-member # type: ignore + + +Callable = CallableProxy() + + +class MatchShapePair: + value: Expr + pattern: List[PrimExpr] + + def __init__(self, value: Expr, pattern: List[PrimExpr]) -> None: + self.value = value + self.pattern = pattern + + +def match_shape(value: Expr, pattern: List[PrimExpr]): + return MatchShapePair(value, pattern) diff --git a/python/tvm/script/_parser/relax/parser.py b/python/tvm/script/_parser/relax/parser.py new file mode 100644 index 0000000000..f8101a6e6c --- /dev/null +++ b/python/tvm/script/_parser/relax/parser.py @@ -0,0 +1,357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import contextlib +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +from tvm import relax, tir +from tvm.ir import Type +from tvm.script.ir_builder.relax.frame import BlockFrame + +from ...ir_builder import relax as R +from ...ir_builder.base import IRBuilder +from .._core import Parser, dispatch, doc +from .entry import MatchShapePair, Tensor, TensorType + + +class VarDefLoc: + def __init__(self, name: str, line: int, col: int): + self.name = name + self.line = line + self.col = col + + def __str__(self): + return f"{self.name}@{self.line}:{self.col}" + + def __repr__(self): + return f"{self.name}@{self.line}:{self.col}" + + +def collect_var_definitions(stmts: List[doc.stmt]) -> Dict[str, List[VarDefLoc]]: + class Collector(doc.NodeVisitor): + results: Dict[str, List[VarDefLoc]] + + def __init__(self): + self.results = defaultdict(list) + + def visit_Name(self, node: doc.Name): # pylint: disable=invalid-name + assert isinstance(node.ctx, doc.Store) + assert node.id + assert node.lineno is not None + assert node.col_offset is not None + self.results[node.id].append( + VarDefLoc( + node.id, + node.lineno, + node.col_offset, + ) + ) + + collector = Collector() + for stmt in stmts: + if isinstance(stmt, doc.Assign): + assert len(stmt.targets) == 1 + collector.visit(stmt.targets[0]) + elif isinstance(stmt, doc.AugAssign): + collector.visit(stmt.target) + + return collector.results + + +def bind_value_with_dataflow_var_names( + dataflow_var_names: List[str], var_def_table: Optional[Dict[str, List[VarDefLoc]]] +): + def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + var_table = self.var_table.get() + + if isinstance(value, tir.Var): + if value.name and var_name != value.name: + self.report_error( + node, + "Cannot define TIR variables with different names. The LHS of the binding " + "should have the same name as the variable on the RHS.", + ) + if var_name in var_table: + prev_value = var_table[var_name] + if not isinstance(prev_value, tir.Var): + self.report_error( + node, + "Cannot redefine a non-TIR-variable object as a TIR variable. Please " + "define the TIR variable with another name.", + ) + if prev_value.dtype != value.dtype: + self.report_error( + node, + "Expected the same dtype for TIR vars " + f"but got {value.dtype} vs {prev_value.dtype}", + ) + return prev_value + IRBuilder.name(var_name, value) + return value + + is_dataflow_var = False + if var_def_table is not None and ( + var_name not in dataflow_var_names or node.lineno != var_def_table[var_name][-1].line + ): + is_dataflow_var = True + + if isinstance(value, relax.Expr): + var = R.emit(value, is_dataflow_var) + # It's an internal check, so directly use assert here. + assert var is not None + IRBuilder.name(var_name, var) + return var + elif isinstance(value, MatchShapePair): + var = R.emit_match_shape( + value.value, value.pattern, emit_var=True, is_dataflow_var=is_dataflow_var + ) + # It's an internal check, so directly use assert here. 
+ assert var is not None + IRBuilder.name(var_name, var) + return var + else: + raise TypeError(f"Unsupported type {type(value)} in assignment") + + return bind_assign_value + + +def eval_type_annotation(self: Parser, node: Union[doc.Expression, doc.expr]) -> Any: + type_annotation = self.eval_expr(node) + if callable(type_annotation): + type_annotation = Tensor() + if isinstance(type_annotation, TensorType): + shape = type_annotation.shape + if shape is None: + return type_annotation.type, None + shape = list(shape.values) + var_table = self.var_table.get() + for i, expr in enumerate(shape): + # Define the symbolic shape var + if isinstance(expr, tir.Var): + name = expr.name + if name in var_table: + shape[i] = var_table[name] + else: + self.var_table.add(name, shape[i]) + return type_annotation.type, relax.ShapeExpr(shape) + else: + if not isinstance(type_annotation, Type): + self.report_error(node, f"Unsupported type annotation {type(type_annotation)}") + return type_annotation, None + + +@dispatch.register(token="relax", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + with R.function(): + R.func_name(node.name) + if node.returns is not None: + ann_type, _ = eval_type_annotation(self, node.returns) + R.func_ret_type(ann_type) + with self.with_dispatch_token("relax"): + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="relax", type_name="pre_token_switch") +def pre_token_switch(self: Parser, node: doc.Expr) -> None: # pylint: disable=unused-argument + ir_builder = IRBuilder() + ir_builder.__enter__() + + +@dispatch.register(token="relax", type_name="post_token_switch") +def post_token_switch(self: Parser, node: doc.Expr) -> None: + ir_builder = IRBuilder.current() + result = ir_builder.get() + ir_builder.__exit__(None, None, None) + var = R.emit(result, is_dataflow_var=False) + IRBuilder.name(node.name, var) + self.var_table.add(node.name, var, allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + value = self.eval_expr(node.value) + if isinstance(value, MatchShapePair): + R.emit_match_shape(value.value, value.pattern, emit_var=False, is_dataflow_var=False) + elif isinstance(value, tuple): + # Currently `value` must be the return value of `R.output`. To make these variables + # accessible to the bindings of the following binding blocks, we promote them into + # the variable table one level higher. + for var_name in self.var_table.frames[-1].vars: + if self.var_table.name2value[var_name][-1] in value: + var = self.var_table.name2value[var_name][-1] + # Promote the variable to the variable table one level higher. 
+ if var_name in self.var_table.frames[-2].vars: + self.var_table.name2value[var_name][-2] = var + else: + self.var_table.frames[-2].add(var_name) + self.var_table.name2value[var_name].append(var) + elif value is not None: + self.report_error(node, f"Unsupported Expr stmt type {value}.") + + +@dispatch.register(token="relax", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_type, param_shape = self.visit_tvm_annotation(arg.annotation) + param = R.arg(arg.arg, param_type, param_shape) + + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="relax", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr): + return eval_type_annotation(self, node) + + +@dispatch.register(token="relax", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + # Currently only `with R.dataflow()` is supported + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + if len(node.items) != 1: + self.report_error(node, "Only one dataflow block is allowed") + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, BlockFrame): + self.report_error( + item.context_expr, "Invalid context expression in the with-statement." + ) + stack.enter_context(frame) + if item.optional_vars is not None: + self.report_error( + item.context_expr, + "Relax syntax doesn't allow binding expressions in `with` to variables", + ) + + assert isinstance(node.body, list) + var_def_table = collect_var_definitions(node.body) + + if ( + not isinstance(node.body[-1], doc.Expr) + or not isinstance(node.body[-1].value, doc.Call) + or node.body[-1].value.func.attr != "output" + ): + self.report_error( + node.body[-1], + "A Relax dataflow block must end with `R.output`. However, the last statement " + "inside this dataflow block is not `R.output`. Please use `R.output` to specify " + "the outputs of the dataflow block.", + ) + + dataflow_var_names = [] + for arg in node.body[-1].value.args: + if not isinstance(arg, doc.Name): + self.report_error( + arg, + "The outputs of a Relax dataflow block must all be variables. However, one of " + "the outputs of this dataflow block is not a variable. Please make sure all " + "outputs are variables.", + ) + dataflow_var_names.append(arg.id) + + for i in range(len(node.body) - 1): + if not isinstance(node.body[i], doc.Assign): + self.report_error( + node.body[i], + "A non-assignment statement appears unexpectedly inside a dataflow block. " + "Only the last statement inside a dataflow block may be a bare Expr. Please " + "make sure this statement appears at the correct position.", + ) + if len(node.body[i].targets) != 1: + self.report_error( + node.body[i], "Chained assignments like 'a = b = c' are not supported." 
+ ) + lhs = node.body[i].targets[0] + rhs = self.eval_expr(node.body[i].value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_value_with_dataflow_var_names(dataflow_var_names, var_def_table), + allow_shadowing=True, + ) + + self.visit(node.body[-1]) + + +@dispatch.register(token="relax", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Chained assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_value_with_dataflow_var_names(dataflow_var_names=[], var_def_table=None), + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + ann_type, ann_shape = self.visit_tvm_annotation(node.annotation) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_value_with_dataflow_var_names(dataflow_var_names=[], var_def_table=None), + allow_shadowing=True, + ) + var = self.var_table.get().get(lhs.id) + assert isinstance(var, relax.Var) + R.ir.annotate_type_shape(var, ann_type, ann_shape) + + +@dispatch.register(token="relax", type_name="Return") +def visit_return(self: Parser, node: doc.Return) -> None: + value = self.eval_expr(node.value) + + if isinstance(value, relax.Expr): + R.func_ret_value(value) + elif isinstance(value, tuple): + if all(isinstance(f, tir.PrimExpr) for f in value): + R.func_ret_value(relax.ShapeExpr(value)) + elif any(isinstance(f, tir.PrimExpr) for f in value): + self.report_error( + node, "Return values that mix PrimExpr and Relax Expr are not supported." + ) + else: + R.func_ret_value(relax.Tuple(value)) + else: + self.report_error(node, f"Unsupported return value type {type(value)}.") + + +@dispatch.register(token="relax", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + if not node.orelse: + raise ValueError("An else branch is required in the Relax dialect.") + with R.If(self.eval_expr(node.test)) as if_frame: + with self.var_table.with_frame(): + with R.Then(): + self.visit_body(node.body) + with self.var_table.with_frame(): + with R.Else(): + self.visit_body(node.orelse) + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/python/tvm/script/_parser/tir/__init__.py b/python/tvm/script/_parser/tir/__init__.py new file mode 100644 index 0000000000..930764f73d --- /dev/null +++ b/python/tvm/script/_parser/tir/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
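# An aside on the relax `With`/`Assign` visitors above: inside `with R.dataflow()`,
# a binding becomes a DataflowVar (local to the block) unless its name is listed
# in the trailing `R.output(...)` call and that binding is the last definition of
# the name in the block. A sketch of the surface syntax those rules accept,
# assuming the parser is exposed as `from tvm.script import relax as R`
# (a hypothetical entry point for this snapshot):
#
#     @R.function
#     def f(x: R.Tensor((4,), "float32")):
#         with R.dataflow():
#             y = R.add(x, x)        # DataflowVar: "y" is not in R.output
#             z = R.multiply(y, x)   # last definition of "z" -> ordinary Var
#             R.output(z)            # required last statement of the block
#         return z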
+# pylint: disable=missing-docstring +from ...ir_builder.tir import * # pylint: disable=redefined-builtin +from ...ir_builder.tir import ir as _tir +from . import operation as _operation +from . import parser as _parser +from .entry import Buffer, Ptr, prim_func + +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/_parser/tir/entry.py new file mode 100644 index 0000000000..07bd75f351 --- /dev/null +++ b/python/tvm/script/_parser/tir/entry.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import inspect +from typing import Callable, Union + +from tvm.tir import Buffer, PrimFunc + +from ...ir_builder.tir import buffer_decl, ptr +from .._core import parse, utils +from ..ir import is_defined_in_class + + +def prim_func(f: Callable) -> Union[PrimFunc, Callable]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if is_defined_in_class(inspect.stack()): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(prim_func, "dispatch_token", "tir") + + +class BufferProxy: + def __call__( + self, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Buffer: + return buffer_decl( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def __getitem__(self, keys) -> Buffer: + if not isinstance(keys, tuple): + return self(keys) + if len(keys) >= 2 and not isinstance(keys[1], str): + return self(keys) + return self(*keys) # pylint: disable=no-member # type: ignore + + +class PtrProxy: + def __call__(self, dtype, storage_scope="global"): + if callable(dtype): + dtype = dtype().dtype + return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + + def __getitem__(self, keys): + if not isinstance(keys, tuple): + return self(keys) + return self(*keys) + + +Buffer = BufferProxy() # pylint: disable=invalid-name +Ptr = PtrProxy() # pylint: disable=invalid-name diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/_parser/tir/operation.py new file mode 100644 index 0000000000..87fb9406ae --- /dev/null +++ b/python/tvm/script/_parser/tir/operation.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import Type + +from tvm import tir +from tvm.tir import IntImm + +from .._core import OpMethod, doc, register_op + + +def _register_expr_op(ty: Type): # pylint: disable=invalid-name + ty._dispatch_type = ty # pylint: disable=protected-access + + def _and(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + return tir.And(a, b) + + def _or(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + return tir.Or(a, b) + + def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name + register_op(ty, op, i)(m) + + for i in [0, 1]: + # Case 1. binop + r(doc.Add, i, lambda a, b: a + b) + r(doc.Sub, i, lambda a, b: a - b) + r(doc.Mult, i, lambda a, b: a * b) + r(doc.Div, i, lambda a, b: a / b) + r(doc.FloorDiv, i, lambda a, b: a // b) + r(doc.Mod, i, lambda a, b: a % b) + r(doc.LShift, i, lambda a, b: a << b) + r(doc.RShift, i, lambda a, b: a >> b) + r(doc.BitOr, i, lambda a, b: a | b) + r(doc.BitXor, i, lambda a, b: a ^ b) + r(doc.BitAnd, i, lambda a, b: a & b) + # doc.MatMult <-- not implemented + # doc.Pow <-- not implemented + # Case 2. cmpop + r(doc.Eq, i, tir.EQ) + r(doc.NotEq, i, tir.NE) + r(doc.Lt, i, tir.LT) + r(doc.LtE, i, tir.LE) + r(doc.Gt, i, tir.GT) + r(doc.GtE, i, tir.GE) + # doc.Is <-- not implemented + # doc.IsNot <-- not implemented + # doc.In <-- not implemented + # doc.NotIn <-- not implemented + # Case 3. boolop + r(doc.And, i, _and) + r(doc.Or, i, _or) + for i in [0]: + # Case 4. unaryop + r(doc.Invert, i, lambda a: ~a) + r(doc.Not, i, tir.Not) + r(doc.UAdd, i, lambda a: +a) + r(doc.USub, i, lambda a: -a) + + +_register_expr_op(tir.PrimExpr) +_register_expr_op(tir.IterVar) diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/_parser/tir/parser.py new file mode 100644 index 0000000000..032555187f --- /dev/null +++ b/python/tvm/script/_parser/tir/parser.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
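# An aside on `_register_expr_op` above: `tir.And`/`tir.Or` expect PrimExpr
# operands, but an expression such as `cond and True` hands the parser a raw
# Python bool, hence the `IntImm("bool", ...)` coercion in `_and`/`_or`. A
# runnable sketch of that coercion using real tvm.tir APIs:
from tvm import tir

a = tir.Var("i", "int32") < 4             # a PrimExpr with dtype "bool"
b = True                                  # a plain Python bool from the host language
expr = tir.And(a, tir.IntImm("bool", b))  # coerce, then build the And node
print(expr)                               # e.g. i < 4 && (bool)1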
+# pylint: disable=missing-docstring +import contextlib +from functools import partial +from typing import Any + +from tvm.ir import PrimType +from tvm.tir import Buffer, IterVar, PrimExpr, Var + +from ...ir_builder import tir as T +from ...ir_builder.base import IRBuilder +from ...ir_builder.base import IRBuilderFrame as Frame +from .._core import Parser, dispatch, doc + + +def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_with_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, (Buffer, Var)): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement") + raise NotImplementedError + + +def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_for_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Var): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement") + raise NotImplementedError + + +def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any: + if isinstance(value, T.inline): + return value.value + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_assign_value(self, _node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Frame): + value.add_callback(partial(value.__exit__, None, None, None)) + res = value.__enter__() + IRBuilder.name(var_name, res) + return res + elif isinstance(value, (Buffer, IterVar)) or ( + isinstance(value, Var) and not self.var_table.exist(value) + ): + IRBuilder.name(var_name, value) + return value + elif isinstance(value, PrimExpr): + var = T.var(value.dtype) + IRBuilder.name(var_name, var) + frame = T.let(var, value) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + return var + return value + + +@dispatch.register(token="tir", type_name="For") +def visit_for(self: Parser, node: doc.For) -> None: + for_frame = self.eval_expr(node.iter) + if not isinstance(for_frame, T.frame.ForFrame): + self.report_error( + node.iter, + "Expect the for loop to be one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + ) + with self.var_table.with_frame(): + with for_frame as iters: + self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="While") +def visit_while(self: Parser, node: doc.While) -> None: + with self.var_table.with_frame(): + cond = self.eval_expr(node.test) + with T.While(cond): + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Chained assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + 
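# An aside on the recurring pattern in `bind_assign_value` above: entering a
# frame immediately and scheduling its own __exit__ as a callback turns a
# with-style context (e.g. T.let) into a "rest of the current scope" binding.
# A standalone sketch of the mechanics with a stand-in Frame class
# (hypothetical, not TVM's IRBuilderFrame):
from functools import partial

class Frame:
    def __init__(self, name):
        self.name = name
        self.callbacks = []

    def __enter__(self):
        print(f"enter {self.name}")
        return self

    def __exit__(self, exc_type, exc_value, trace):
        print(f"exit {self.name}")

    def add_callback(self, cb):
        self.callbacks.append(cb)

frame = Frame("let x = value")
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()            # the binding is now live for the rest of the scope
# ... further statements are parsed here, all seeing the binding ...
for cb in frame.callbacks:   # the enclosing scope drains callbacks on exit
    cb()                     # -> "exit let x = value"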
+@dispatch.register(token="tir", type_name="AugAssign") +def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: + lhs_pos = ( + node.target.lineno, + node.target.col_offset, + node.target.end_lineno, + node.target.end_col_offset, + ) + rhs_pos = ( + node.value.lineno, + node.value.col_offset, + node.value.end_lineno, + node.value.end_col_offset, + ) + node.target.ctx = doc.Load(*lhs_pos) + with self.var_table.with_frame(): + lhs_name = "__tvm_tmp_value_aug_assign_lhs" + rhs_name = "__tvm_tmp_value_aug_assign_rhs" + lhs_expr = self.eval_expr(node.target) + rhs_expr = self.eval_expr(node.value) + self.var_table.add(lhs_name, lhs_expr) + self.var_table.add(rhs_name, rhs_expr) + op = doc.BinOp( + doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos), + node.op, + doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos), + *lhs_pos, + ) + rhs = self.eval_expr(op) + lhs = node.target + lhs.ctx = doc.Store(*lhs_pos) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + ann_var = self.visit_tvm_annotation(node.annotation) + if not isinstance(ann_var, Var): + self.report_error(node.annotation, "Annotation should be Var") + self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) + frame = T.let(ann_var, rhs) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, Frame): + self.report_error( + item.context_expr, "Invalid context expression in the with-statement." 
+ ) + rhs = stack.enter_context(frame) + if item.optional_vars is not None: + self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + self.var_table.add("range", T.serial) + with T.prim_func(): + T.func_name(node.name) + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + T.func_ret(ret_type) + with self.with_dispatch_token("tir"): + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + # TODO: handle different types of arguments: + # - vararg: arg | None + # - kwonlyargs: list[arg] + # - kw_defaults: list[expr | None] + # - kwarg: arg | None + # - defaults: list[expr] + # - posonlyargs: list[arg] + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation)) + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="tir", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr): + annotation = self.eval_expr(node) + if callable(annotation): + annotation = annotation() + return annotation + + +@dispatch.register(token="tir", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + res = self.eval_expr(node.value) + if isinstance(res, Frame): + res.add_callback(partial(res.__exit__, None, None, None)) + res.__enter__() + + +@dispatch.register(token="tir", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + with self.var_table.with_frame(): + with T.If(self.eval_expr(node.test)): + with T.Then(): + self.visit_body(node.body) + if node.orelse: + with T.Else(): + self.visit_body(node.orelse) + + +@dispatch.register(token="tir", type_name="Assert") +def visit_assert(self: Parser, node: doc.Assert) -> None: + cond = self.eval_expr(node.test) + msg = self.eval_expr(node.msg) + frame = T.Assert(cond, msg) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="Return") +def visit_return(self: Parser, node: doc.Return) -> None: + self.report_error(node, "Return is not allowed.") diff --git a/python/tvm/script/highlight.py b/python/tvm/script/highlight.py index dc45b5a3f1..29a8c10a0d 100644 --- a/python/tvm/script/highlight.py +++ b/python/tvm/script/highlight.py @@ -25,12 +25,12 @@ from tvm.tir import PrimFunc -def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = None) -> None: +def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> None: """ Print highlighted TVM script string with Pygments Parameters ---------- - printable : Union[IRModule, PrimFunc, str] + printable : Union[IRModule, PrimFunc] The TVM script to be printed style : str, optional Printing style, auto-detected if None. @@ -44,8 +44,7 @@ def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = Non installing the Pygment library. 
Other Pygment styles can be found in https://pygments.org/styles/ """ - if isinstance(printable, (IRModule, PrimFunc)): - printable = printable.script() + try: # pylint: disable=import-outside-toplevel import pygments @@ -69,7 +68,7 @@ def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = Non + install_cmd, category=UserWarning, ) - print(printable) + print(printable.script()) else: class JupyterLight(Style): @@ -79,8 +78,8 @@ class JupyterLight(Style): styles = { Keyword: "bold #008000", Keyword.Type: "nobold #008000", - Name.Function: "#0000FF", - Name.Class: "bold #0000FF", + Name.Function: "#1E90FF", + Name.Class: "bold #1E90FF", Name.Decorator: "#AA22FF", String: "#BA2121", Number: "#008000", @@ -144,7 +143,7 @@ class AnsiTerminalDefault(Style): formatter = HtmlFormatter(style=JupyterLight) formatter.noclasses = True # inline styles - html = highlight(printable, Python3Lexer(), formatter) + html = highlight(printable.script(), Python3Lexer(), formatter) display(HTML(html)) else: - print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style))) + print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style))) diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c..69f13b2145 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with-scope exits because of an exception + _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737..946be263a7 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463c..ac7d479e1a 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,46 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start an ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str) -> GlobalVar: + """Declare a Function without giving the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + Note + ---- + It is usually used in cross-function calls. 
And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + return _ffi_api.DeclFunction(func_name) # pylint: disable=no-member # type: ignore + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/ir_builder/relax/__init__.py b/python/tvm/script/ir_builder/relax/__init__.py new file mode 100644 index 0000000000..f0905acf34 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Package tvm.script.ir_builder.relax""" +from . import frame +from .ir import * # pylint: disable=wildcard-import,redefined-builtin diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py new file mode 100644 index 0000000000..6e2098cf88 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder.relax""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py new file mode 100644 index 0000000000..97e181fbe4 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IR Builder Frame for Relax dialect""" +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.relax.RelaxFrame") +class RelaxFrame(IRBuilderFrame): + """The base ir_builder frame for the relax dialect.""" + + +@_register_object("script.ir_builder.relax.SeqExprFrame") +class SeqExprFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.FunctionFrame") +class FunctionFrame(SeqExprFrame): + """The ir_builder frame for the relax function.""" + + +@_register_object("script.ir_builder.relax.BlockFrame") +class BlockFrame(RelaxFrame): + """The ir_builder frame for relax binding blocks.""" + + +@_register_object("script.ir_builder.relax.IfFrame") +class IfFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.ThenFrame") +class ThenFrame(SeqExprFrame): + ... + + +@_register_object("script.ir_builder.relax.ElseFrame") +class ElseFrame(SeqExprFrame): + ... diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py new file mode 100644 index 0000000000..ed4c1c1cf8 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -0,0 +1,372 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, wrong-import-order +"""IRBuilder for Relax dialect""" + +from typing import Dict, List, Optional, Tuple, Union + +from tvm._ffi import register_object as _register_object +from tvm.ir import Attrs, Type +from tvm.relax import Call, Expr, ExternFunc, ShapeExpr, Var + +############################### Operators ############################### +from tvm.relax.op import ( + add, + builtin, + call_tir, + invoke_closure, + make_closure, + multiply, + print, + shape_of, + unique, +) +from tvm.relax.ty import ObjectType, ShapeType +from tvm.runtime import Object as tvm_Object +from tvm.tir import PrimExpr + +from ..tir import var as _tir_var +from . 
import _ffi_api, frame + +############################## Tensor Type ############################## + + +@_register_object("script.ir_builder.relax.TensorType") +class TensorType(tvm_Object): + """A temporary Tensor type for `R.Tensor` in ir_builder.""" + + +def tensor( + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: Optional[str] = None, + ndim: int = -1, +): + """Helper function for `R.Tensor` in parser + Parameters + ---------- + shape: Optional[List[Union[PrimExpr, str]]] + The shape of the tensor. It's runtime dependent if `shape` is None. + dtype: Optional[str] + The element data type of the tensor. It's runtime dependent if `dtype` is None. + ndim: int + The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1. + Returns + ------- + tensor_type: TensorType + The TensorType that is only used in ir_builder. + """ + + if shape is not None: + if not isinstance(shape, list): + shape = list(shape) + + for i, s in enumerate(shape): + if isinstance(s, str): + shape[i] = _tir_var("int64", s) + + return _ffi_api.Tensor(shape, dtype, ndim) # pylint: disable=no-member # type: ignore + + +############################## Other Types ############################## + +Object = ObjectType() # pylint: disable=invalid-name +Shape = ShapeType() # pylint: disable=invalid-name + +############################### Function ################################ + + +def function() -> frame.FunctionFrame: + """Start a function frame. + Returns + ------- + frame: FunctionFrame + The constructed function frame. + """ + return _ffi_api.Function() # pylint: disable=no-member # type: ignore + + +def arg(name: str, type: Union[Type, TensorType], shape: Optional[ShapeExpr] = None) -> Var: + """Add a parameter to the last function frame. + Parameters + ---------- + name: str + The name of the parameter. + type: Union[Type, TensorType] + The type of the parameter. It can be a typical TVM Type or a TensorType, + which contains both type and shape + shape: Optional[ShapeExpr] + The shape of the parameter. + Returns + ------- + var: Var + The created function parameter var. + """ + + if isinstance(type, TensorType): + if shape is not None: + raise ValueError("Cannot specify the shape if we use TensorType") + shape = type.shape + type = type.type + + return _ffi_api.Arg(name, type, shape) # pylint: disable=no-member # type: ignore + + +def func_name(name: str) -> None: + """Specify the name of the last function frame. + Parameters + ---------- + name: str + The function name. + """ + return _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore + + +def func_attr(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the last function frame. + Parameters + ---------- + attrs: Dict[str, Object] + The function attrs. + """ + return _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def func_ret_type(ret_type: Union[TensorType, Type]) -> None: + """Specify the return type of the last function frame. + Parameters + ---------- + ret_type: Union[TensorType, Type] + The function return type. + """ + if isinstance(ret_type, TensorType): + ret_type = ret_type.type + return _ffi_api.FuncRetType(ret_type) # pylint: disable=no-member # type: ignore + + +def func_ret_value(value: Expr) -> None: + """Specify the return value of the last function frame. + Parameters + ---------- + value: Expr + The function return value. 
+ """ + return _ffi_api.FuncRetValue(value) # pylint: disable=no-member # type: ignore + + +############################# BindingBlock ############################## + + +def dataflow() -> frame.BlockFrame: + """Start a dataflow binding block frame. + Returns + ------- + frame: frame.BlockFrame + The created ir_builder Block frame. + """ + return _ffi_api.Dataflow() # pylint: disable=no-member # type: ignore + + +def output(*vars: Tuple[Var]) -> Tuple[Var]: + """Expose the dataflow block output variables as global ones. + Parameters + ---------- + vars: Tuple[Var] + The output variables of a dataflow block. + Returns + ------- + vars: Tuple[Var] + The output variables of a dataflow block. Return the input variables to the parser + side for follow-up processing. + """ + _ffi_api.DataflowBlockOutput(vars) # pylint: disable=no-member # type: ignore + return vars + + +################################## Ops ################################# + + +def call_packed( + func: str, + *args: List[Expr], + attrs: Optional[Attrs] = None, + type_args: Optional[Union[TensorType, List[TensorType]]] = None, +) -> Call: + """Create a relax Call, which calls a packed function. + Parameters + ---------- + func: str + The name of the extern function. + args : List[Expr] + The arguments. + attrs: Optional[Attrs] + The call attributes + type_args: Optional[Union[TensorType, List[TensorType]]] + List of Types + Returns + ------- + call: Call + The created Relax Call + """ + op = ExternFunc(func) + if type_args is None: + raise ValueError("R.call_packed is required to have type_args") + if isinstance(type_args, (TensorType, Type)): + type_args = [type_args] + elif isinstance(type_args, tuple): + type_args = list(type_args) + for i, argument in enumerate(type_args): + if isinstance(argument, TensorType): + type_args[i] = argument.type + elif isinstance(argument, Type): + type_args[i] = argument + else: + raise TypeError( + "call_packed `type_args` is expected to be list of TensorType/Type, " + f"but got {type(argument)}" + ) + + return Call(op, args, attrs=attrs, type_args=type_args) + + +############################### Bindings ############################### + + +def emit(value: Expr, is_dataflow_var: bool) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: Expr + The right side value of the bindings to be emitted. + is_dataflow_var: bool + A boolean indicating if the emitted binding variable is a dataflow variable. + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.Emit(value, is_dataflow_var) # pylint: disable=no-member # type: ignore + + +def emit_match_shape( + value: Expr, pattern: List[PrimExpr], emit_var: bool, is_dataflow_var: bool +) -> Optional[Var]: + """Emit a match_shape binding to the last binding block frame. + Parameters + ---------- + value: Expr + The value of the MatchShape to be emitted. + pattern: List[PrimExpr] + The pattern of the MatchShape to be emitted. + emit_var: bool + A boolean indicating if the MatchShape contains the emitted variable. + is_dataflow_var: bool + A boolean indicating if the emitted variable is a dataflow variable when `emit_var` is True. + When `emit_var` is False, the value of this flag will be ignored. + Returns + ------- + var: Optional[Var] + The emitted var if `emit_var` is True. Otherwise, return `None`. 
+ """ + return _ffi_api.EmitMatchShape(value, pattern, emit_var, is_dataflow_var) # type: ignore + + +############################# Type Deduce ############################## + + +def annotate_type_shape(var: Var, anno_type: Type, anno_shape: ShapeExpr) -> None: + """Annotate and check the type of relax var. + Parameters + ---------- + var: Var + The input var to be annotated. + + anno_type: Type + The annotated type + + anno_shape: ShapeExpr + The annotated shape + + """ + _ffi_api.AnnotateTypeShape(var, anno_type, anno_shape) + + +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if frame. + Parameters + ---------- + condition : Expr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + return _ffi_api.If(condition) # pylint: disable=no-member # type: ignore + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then frame. + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # pylint: disable=no-member # type: ignore + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else frame. + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # pylint: disable=no-member # type: ignore + + +############################### Importer ############################### + +__all__ = [ + "Else", + "If", + "Object", + "Shape", + "TensorType", + "Then", + "add", + "arg", + "builtin", + "call_packed", + "call_tir", + "dataflow", + "emit", + "emit_match_shape", + "func_attr", + "func_name", + "func_ret_type", + "func_ret_value", + "function", + "invoke_closure", + "make_closure", + "multiply", + "output", + "print", + "unique", + "shape_of", + "tensor", +] diff --git a/python/tvm/script/parser_v1/__init__.py b/python/tvm/script/parser_v1/__init__.py new file mode 100644 index 0000000000..62279b46c1 --- /dev/null +++ b/python/tvm/script/parser_v1/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script APIs of TVM Python Package, aimed to support TIR""" + +from . import tir +from . 
import relax + +from .parser import ir_module, from_source diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py similarity index 100% rename from python/tvm/script/context_maintainer.py rename to python/tvm/script/parser_v1/context_maintainer.py diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py similarity index 100% rename from python/tvm/script/diagnostics.py rename to python/tvm/script/parser_v1/diagnostics.py diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py similarity index 100% rename from python/tvm/script/meta_unparser.py rename to python/tvm/script/parser_v1/meta_unparser.py diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py similarity index 97% rename from python/tvm/script/parser.py rename to python/tvm/script/parser_v1/parser.py index c34aae2345..b8ec5d0dcc 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser_v1/parser.py @@ -20,35 +20,36 @@ different python versions. Synr also provides an error handling context that we use for error reporting. """ -# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except -import types +import functools +import inspect import json import operator -import inspect + +# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except, import-outside-toplevel +import types from typing import Any, Callable, Dict, List, Optional, Union -from synr import ast, Transformer, to_ast import tvm -from tvm import IRModule +from synr import Transformer, ast, to_ast +from tvm import IRModule, relax from tvm._ffi.base import TVMError from tvm.ir import GlobalVar from tvm.ir.function import BaseFunc from tvm.tir import buffer from tvm.tir.function import PrimFunc -from . import _ffi_api -from . import tir +from .. import _ffi_api +from . 
import tir from .context_maintainer import ContextMaintainer +from .diagnostics import TVMDiagnosticCtx from .meta_unparser import MetaUnparser from .registry import Registry -from .diagnostics import TVMDiagnosticCtx -from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting - +from .tir import ty from .tir.intrin import Intrin -from .tir.node import Slice, BufferSlice -from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler +from .tir.node import BufferSlice, Slice +from .tir.scope_handler import ForScopeHandler, ScopeHandler, WithScopeHandler from .tir.special_stmt import SpecialStmt -from .tir import ty +from .utils import call_with_error_reporting, synr_span_from_tvm, tvm_span_from_synr class CallArgumentReader(object): @@ -604,7 +605,7 @@ def transform_Assign(self, node): out = func(*args) except Exception as e: self.report_error( - "Error occurred when invoking the function " + "Error occured when invoking the function " + func.__name__ + ": \n" + str(e), @@ -906,13 +907,6 @@ def transform_Call(self, node): ) if node.func_name.name in self._unaryop_maker: rhs = self.transform(node.params[0]) - if node.func_name.name == ast.BuiltinOp.USub and isinstance( - node.params[0], ast.Constant - ): - # '-literal' should be parsed together for proper literal type inference - if not isinstance(rhs, (tvm.tir.IntImm, tvm.tir.FloatImm)): - self.report_error("The literal is illegal after -", node.params[0].span) - return tvm.tir.const(-rhs.value) return self._unaryop_maker[node.func_name.name]( rhs, span=tvm_span_from_synr(node.span) ) @@ -1371,7 +1365,7 @@ def from_source( raise TypeError("Only function definitions are supported.") -def ir_module(input_module: type) -> IRModule: +def ir_module(input_module=None, metadata=None) -> IRModule: """Decorate a python class as tvm IRModule. Parameters @@ -1379,14 +1373,33 @@ def ir_module(input_module: type) -> IRModule: input_module : type The python class to be parsed. + metadata : Optional[Union[str, DictAttrs]] + The metadata attributes to be parsed. + Returns ------- - output : IRModule + mod : IRModule The result IRModule. """ - if inspect.isclass(input_module): - func_dict = { - name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) - } - return IRModule(func_dict) - raise TypeError("Only class definitions are supported.") + if metadata is not None: + from .relax.parser import RelaxTransformer as _RelaxTransformer + + _RelaxTransformer.update_meta(metadata) + + if input_module is None: + return functools.partial(ir_module, metadata=metadata) + + def _ir_module(input_module: type) -> IRModule: + if inspect.isclass(input_module): + func_dict = { + name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) + } + mod = IRModule(func_dict, attrs=metadata) + mod = relax.transform.Normalize()(mod) + mod = relax.transform.ResolveGlobals()(mod) + # FIXME(@altanh): where is the source map? 
+ return mod + + raise TypeError("Only class definitions are supported.") + + return _ir_module(input_module) diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py similarity index 100% rename from python/tvm/script/registry.py rename to python/tvm/script/parser_v1/registry.py diff --git a/python/tvm/script/parser_v1/relax/__init__.py b/python/tvm/script/parser_v1/relax/__init__.py new file mode 100644 index 0000000000..a104b7ce92 --- /dev/null +++ b/python/tvm/script/parser_v1/relax/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relax parsing support for TVM script.""" + +from .function import function +from . import parser diff --git a/python/tvm/script/parser_v1/relax/function.py b/python/tvm/script/parser_v1/relax/function.py new file mode 100644 index 0000000000..da861940b1 --- /dev/null +++ b/python/tvm/script/parser_v1/relax/function.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script Interface for Relax Functions""" +# pylint: disable=import-outside-toplevel + +import inspect +from typing import Callable +import functools + +from tvm.relax import Function + +from .parser import from_source + + +def function(input_func=None, metadata=None) -> Function: + """Decorate a Python function as a Relax function in TVM script. + + Parameters + ---------- + input_func : Callable + The function to be parsed. + + metadata : Optional[Union[str, DictAttrs]] + The meta_data attributes to be parsed. + + Returns + ------- + output : Function + The parsed Relax Function. 
+ """ + if metadata is not None: + from .parser import RelaxTransformer as _RelaxTransformer + + _RelaxTransformer.update_meta(metadata) + + if input_func is None: + return functools.partial(function, metadata=metadata) + + def _function(input_func: Callable) -> Function: + if inspect.isfunction(input_func): + result = from_source(input_func) + result.__name__ = input_func.__name__ + result.__qualname__ = input_func.__qualname__ + return result + raise TypeError("Only function definitions are supported.") + + return _function(input_func) diff --git a/python/tvm/script/parser_v1/relax/parser.py b/python/tvm/script/parser_v1/relax/parser.py new file mode 100644 index 0000000000..976bdd55a8 --- /dev/null +++ b/python/tvm/script/parser_v1/relax/parser.py @@ -0,0 +1,1788 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-else-return, too-many-nested-blocks +# pylint: disable=inconsistent-return-statements, ungrouped-imports +# pylint: disable=arguments-differ +"""TVM Script Parser For Relax""" +from __future__ import annotations + +import inspect +import json +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import synr +import tvm +import tvm.script +from synr import Transformer, ast +from tvm import relax, relay, tir +from tvm.ir import diagnostics +from tvm.ir.module import IRModule +from tvm.relax.utils import metadata_partitioner + +from .. import relax as relax_namespace +from .. import tir as tir_namespace +from ..parser import TVMScriptParser as _TIRScriptParser +from ..tir.node import BufferSlice +from ..utils import call_with_error_reporting, tvm_span_from_synr + + +def _is_registered(op_name: str, op_set=None) -> bool: + """Returns whether or not the given operator is registered. + + Parameters + ---------- + op_name : str + The name of the operator. + op_set : Union[Container, Iterable], optional + The collection of registered operator names to check against. If None, the global TVM + operator registry is queried. + + Returns + ------- + bool + True if the specified operator is registered, else False. 
+ """ + if op_set is None: + op_set = tvm.ir._ffi_api.ListOpNames() + return op_name in op_set + + +# NOTE: call_tir is an actual registered operator +class SpecialOp(Enum): + """Relax and TIR operators that have special semantics handled by the parser.""" + + MATCH_SHAPE = "relax.match_shape" + CALL_PACKED = "relax.call_packed" + DATAFLOW = "relax.dataflow" + DATAFLOW_OUTPUT = "relax.output" + TUPLE = "relax.Tuple" + TUPLE_GET_ITEM = "relax.TupleGetItem" + CONST = "relax.const" + CONSTANT = "relax.expr.Constant" + TIR_CAST = "tir.cast" + TIR_MAX = "tir.max" + + +class ArithmeticOp(Enum): + """Arithmetic operators that can desugar to either Relax or TIR PrimExpr operators.""" + + ADD = ast.BuiltinOp.Add + SUB = ast.BuiltinOp.Sub + MUL = ast.BuiltinOp.Mul + DIV = ast.BuiltinOp.Div + FLOOR_DIV = ast.BuiltinOp.FloorDiv + + +RELAX_ARITHMETIC_OP_MAP = { + ArithmeticOp.ADD: relay.op.get("add"), + ArithmeticOp.SUB: relay.op.get("subtract"), + ArithmeticOp.MUL: relay.op.get("multiply"), + ArithmeticOp.DIV: relay.op.get("divide"), + ArithmeticOp.FLOOR_DIV: relay.op.get("floor_divide"), +} + +PRIMEXPR_ARITHMETIC_OP_MAP = { + ArithmeticOp.ADD: tir.Add, + ArithmeticOp.SUB: tir.Sub, + ArithmeticOp.MUL: tir.Mul, + ArithmeticOp.DIV: tir.Div, + ArithmeticOp.FLOOR_DIV: tir.FloorDiv, +} + + +class RelaxTransformer(Transformer): + """A visitor to handle transformations on the Relax AST""" + + meta_attr = None + + def __init__(self, ir_mod: IRModule, relax_prefix: List[str], tir_prefix: List[str]): + super().__init__() + self.mod = ir_mod + self.relax_prefix = relax_prefix + self.tir_prefix = tir_prefix + self._scopes = [{}] # str -> Var + self._registered_ops = set(tvm.ir._ffi_api.ListOpNames()) # cached + + def to_tvm_span(self, span: ast.Span) -> tvm.ir.Span: + """Helper method for converting synr span to TVM span. + + Parameters + ---------- + span : ast.Span + The synr span + + Returns + ------- + tvm.ir.Span + The corresponding TVM span + """ + return self._diagnostic_context.to_tvm_span(span) + + def report_error(self, msg: str, span: ast.Span): + """Helper method for emitting and immediately rendering an error. + + Parameters + ---------- + msg : str + The error message + span : ast.Span + The span to report the error at + """ + self._diagnostic_context.emit("error", msg, self.to_tvm_span(span)) + self._diagnostic_context.render() + + def new_scope(self): + """Helper method for creating a new scope context object + + Returns + ------- + _Scope + An internal scope context object used in a with block to create a new scope + """ + + class _Scope: + def __init__(self, transformer: "RelaxTransformer"): + self.transformer = transformer + + def __enter__(self): + self.transformer._scopes.append(self.transformer._scopes[-1].copy()) + + def __exit__(self, *exc): + assert len(self.transformer._scopes) > 1, "cannot pop root scope" + self.transformer._scopes.pop() + + return _Scope(self) + + @classmethod + def update_meta(cls, metadata: str): + """Update the metadata attributes. + + Parameters + ---------- + metadata : str + The metadata to be parsed. + """ + + cls.meta_attr = metadata + + @classmethod + def get_meta(cls) -> str: + """Return the metadata attribute. + + Returns + ------- + str: + The metadata attributes + """ + return cls.meta_attr + + @property + def scope(self): + """Returns the current definition scope. + + Returns + ------- + Dict[str, Union[relax.Var, tir.Var]] + The scope of all currently defined variables (Relax and TIR). 
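+
+        Examples
+        --------
+        Illustrative only: after ``decl_var("x", ty, shape, span)`` runs,
+        ``self.scope["x"]`` holds the newly declared variable until the
+        enclosing ``new_scope()`` context exits.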
+        """
+        return self._scopes[-1]
+
+    def decl_var(
+        self,
+        name: str,
+        type_annotation: Optional[relax.Type],
+        shape: Optional[relax.Expr],
+        span: ast.Span,
+        is_dataflow: bool = False,
+    ) -> relax.Var:
+        """Introduces a variable with the given name and annotations to the current scope.
+
+        Parameters
+        ----------
+        name : str
+            The name of the variable
+        type_annotation : Optional[relax.Type]
+            The type annotation
+        shape : Optional[relax.Expr]
+            The shape annotation
+        span : ast.Span
+            The span where the variable is declared
+        is_dataflow : bool, optional
+            Whether the variable is a dataflow variable, by default False
+
+        Returns
+        -------
+        Union[relax.Var, relax.DataflowVar]
+            The declared variable
+        """
+        if name in self.scope:
+            # TODO(@altanh): maybe emit an error at the declaration site and report it together
+            self.report_error("variable has already been declared in the current scope", span)
+        if is_dataflow:
+            var = relax.DataflowVar(name, shape, type_annotation, self.to_tvm_span(span))
+        else:
+            var = relax.Var(name, shape, type_annotation, self.to_tvm_span(span))
+        self.scope[name] = var
+        return var
+
+    def parse_tensor_kwargs_value(
+        self, ty: ast.Type, span=None
+    ) -> Union[str, None, bool, complex, float, int]:
+        """Parse value of Tensor annotation keyword parameters in synr ast and return value in
+        primitive type.
+
+        Parameters
+        ----------
+        ty : ast.Type
+            The value of one of the Tensor annotation keyword parameters
+        span : ast.Span, optional
+            The span of the annotation, used for error reporting
+
+        Returns
+        -------
+        Union[str, None, bool, complex, float, int]
+            Parsed value in primitive type
+        """
+
+        if isinstance(ty, ast.TypeConstant):
+            return ty.value
+
+        if isinstance(ty, ast.TypeCall):
+            if ty.func_name == ast.BuiltinOp.UAdd:
+                assert len(ty.params) == 1
+
+                val = self.parse_tensor_kwargs_value(ty.params[0], span)
+                if not isinstance(val, int):
+                    self.report_error(f"expected int, but got {type(val)}", span)
+                return val
+
+            if ty.func_name == ast.BuiltinOp.USub:
+                assert len(ty.params) == 1
+
+                val = self.parse_tensor_kwargs_value(ty.params[0], span)
+                if not isinstance(val, int):
+                    self.report_error(f"expected int, but got {type(val)}", span)
+                return 0 - val
+            self.report_error(f"unsupported op: {ty.func_name}", ty.span)
+
+        self.report_error(f"unexpected value of keyword argument {ty}", ty.span)
+
+    def parse_tensor_kwargs(self, ty: ast.Type) -> dict[str, int]:
+        """Parse keyword parameters of Tensor type annotation.
+
+        Parameters
+        ----------
+        ty : ast.Type
+            The Tensor type annotation
+
+        Returns
+        -------
+        dict[str, int]
+            Parsed keyword parameters in dict[key, value] format
+        """
+        kwargs = {}
+        for key, val in ty.keyword_params.items():
+            assert isinstance(key, ast.TypeConstant) and isinstance(key.value, str)
+            kwargs[key.value] = self.parse_tensor_kwargs_value(val, ty.span)
+
+        # sanity check for Tensor type keyword arguments
+        if len(kwargs) == 0:
+            return kwargs
+        if not (len(kwargs) == 1 and "ndim" in kwargs.keys()):
+            self.report_error(
+                f"expected one keyword argument 'ndim' but got {list(kwargs)}", ty.span
+            )
+        if not isinstance(kwargs["ndim"], int):
+            self.report_error(
+                f"expected 'ndim' to be of type int, but got {type(kwargs['ndim'])}", ty.span
+            )
+        if kwargs["ndim"] < -1:
+            self.report_error(f"ndim must be >= -1, but got {kwargs['ndim']}", ty.span)
+        return kwargs
+
+    def parse_dyn_tensor_type(
+        self, ty: Union[ast.Type, ast.Call], bind_free_vars: bool
+    ) -> Tuple[relax.Type, relax.Expr]:
+        """Transforms the given synr tensor type annotation to a Relax DynTensorType.
+
+        Parameters
+        ----------
+        ty : ast.Type or ast.Call
+            The synr type
+        bind_free_vars : bool
+            Whether or not the shape annotation can introduce new dimension variables
+
+        Returns
+        -------
+        Tuple[relax.Type, relax.Expr]:
+            The corresponding Relax type and shape expression
+        """
+
+        # TODO(@altanh): forgetting dtype like "Tensor((n, m))" ends up getting parsed as
+        # Tensor(n, m) which makes correct errors difficult here...
+        if len(ty.params) != 2:
+            self.report_error(
+                "Tensor type annotations must have 2 positional fields (shape and dtype)"
+                " and one optional keyword field ndim",
+                ty.span,
+            )
+
+        shape_annotation, dtype_annotation = ty.params
+        shape, dtype, ndim = None, None, -1
+
+        # parse the shape annotation
+        if isinstance(shape_annotation, ast.TypeConstant) and shape_annotation.value is None:
+            pass  # shape = None
+        elif isinstance(shape_annotation, ast.TypeVar):
+            if shape_annotation.id.name != "_":
+                # TODO(@altanh): handle variable annotations, e.g. x: Tensor(my_shape, _)
+                self.report_error(
+                    "variable Tensor shape annotations not yet supported",
+                    shape_annotation.span,
+                )
+            else:
+                shape = relax.RuntimeDepShape(span=self.to_tvm_span(shape_annotation.span))
+        elif isinstance(shape_annotation, ast.TypeTuple):
+            shape = relax.ShapeExpr(
+                self.parse_shape(shape_annotation, bind_free_vars),
+                span=self.to_tvm_span(shape_annotation.span),
+            )
+            ndim = len(shape)
+        elif isinstance(shape_annotation, ast.Tuple):
+            shape = relax.ShapeExpr(
+                self.parse_shape(shape_annotation, bind_free_vars),
+                span=self.to_tvm_span(shape_annotation.span),
+            )
+            ndim = len(shape)
+        else:
+            self.report_error(
+                f"unsupported shape annotation {shape_annotation}",
+                shape_annotation.span,
+            )
+
+        # parse the dtype annotation
+        if isinstance(dtype_annotation, ast.TypeVar) and dtype_annotation.id.name == "_":
+            pass  # dtype = None
+        elif isinstance(dtype_annotation, ast.TypeConstant):
+            dtype = dtype_annotation.value
+        elif isinstance(dtype_annotation, ast.Constant):
+            dtype = dtype_annotation.value
+        else:
+            self.report_error(
+                "Tensor dtype annotations must be concrete or erased",
+                dtype_annotation.span,
+            )
+        # parse optional keyword argument "ndim" if present
+        kwargs = self.parse_tensor_kwargs(ty)
+        if "ndim" in kwargs.keys():
+            # If ndim was also inferred from shape annotation, then it must match keyword
+            # argument ndim.
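+            # e.g. `Tensor((n, m), "float32", ndim=2)` is consistent, while
+            # `Tensor((n, m), "float32", ndim=3)` triggers the mismatch error
+            # below (illustrative examples).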
+            if ndim >= 0 and kwargs["ndim"] != ndim:
+                self.report_error(
+                    f"number of shape dimensions must match ndim: {ndim} vs. {kwargs['ndim']}",
+                    ty.span,
+                )
+            else:
+                ndim = kwargs["ndim"]
+        span = self.to_tvm_span(ty.span)
+        return (relax.DynTensorType(ndim=ndim, dtype=dtype, span=span), shape)
+
+    def transform_type(self, ty: ast.Type, bind_free_vars: bool) -> Tuple[relax.Type, relax.Expr]:
+        """Transforms the given synr type annotation to a Relax type and shape expression.
+
+        Parameters
+        ----------
+        ty : ast.Type
+            The synr type
+        bind_free_vars : bool
+            Whether or not the shape annotation can introduce new dimension variables
+
+        Returns
+        -------
+        Tuple[relax.Type, relax.Expr]:
+            The corresponding Relax type and shape expression
+        """
+        if ty is None:
+            return (None, None)
+
+        span = self.to_tvm_span(ty.span)
+
+        # simple annotation with no type arguments
+        if isinstance(ty, ast.TypeVar):
+            if ty.id.name == "Tensor":
+                return (relax.DynTensorType(ndim=-1, dtype=None, span=span), None)
+            elif ty.id.name == "Shape":
+                return (relax.ShapeType(span), None)
+            elif ty.id.name == "Object":
+                return (relax.ObjectType(span), None)
+            elif ty.id.name == "Dim":
+                return (relax.DimType(span), None)
+            self.report_error("unknown type in annotation", ty.span)
+
+        # annotation with type arguments/shape annotation
+        if isinstance(ty, ast.TypeCall):
+            if ty.func_name.id.name == "Tensor":
+                return self.parse_dyn_tensor_type(ty, bind_free_vars)
+            elif ty.func_name.id.name == "Tuple":
+                field_types = []
+                field_shapes = []
+                for field in ty.params:
+                    fty, fsh = self.transform_type(field, bind_free_vars=False)
+                    field_types.append(fty)
+                    field_shapes.append(fsh)
+                return (relay.TupleType(field_types, span), None)
+            elif ty.func_name.id.name == "Callable":
+                if len(ty.params) != 2:
+                    self.report_error(
+                        "Function type annotations must have 2 positional fields",
+                        ty.span,
+                    )
+
+                func_arg_types, func_ret_type = ty.params
+                input_tensors = []
+                # Single input
+                if isinstance(func_arg_types, ast.TypeCall):
+                    tensor_type = self.parse_dyn_tensor_type(func_arg_types, bind_free_vars)
+                    input_tensors.append(tensor_type[0])
+                # Multiple inputs
+                elif isinstance(func_arg_types, ast.TypeTuple):
+                    for func_arg_type in func_arg_types.values:
+                        tensor_type = self.parse_dyn_tensor_type(func_arg_type, bind_free_vars)
+                        input_tensors.append(tensor_type[0])
+                else:
+                    self.report_error(
+                        "Function argument type annotations must be a Tensor or a tuple of Tensors",
+                        func_arg_types.span,
+                    )
+
+                ret_type = self.transform_type(func_ret_type, bind_free_vars)
+
+                return (relax.FuncType(input_tensors, ret_type[0]), None)
+
+            self.report_error("invalid type", ty.span)
+
+    def parse_shape(
+        self,
+        shape_annotation: Union[ast.TypeTuple, ast.Tuple],
+        bind_free_vars: bool,
+    ) -> List[tir.PrimExpr]:
+        """Parses the given shape annotation to a list of PrimExprs.
+
+        Parameters
+        ----------
+        shape_annotation : Union[ast.TypeTuple, ast.Tuple]
+            The shape annotation in synr
+        bind_free_vars : bool
+            Whether or not the annotation can bind previously free variables
+
+        Returns
+        -------
+        List[tir.PrimExpr]
+            The parsed shape as a list of PrimExprs
+        """
+        return [self.parse_primexpr(field, bind_free_vars) for field in shape_annotation.values]
+
+    def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
+        """Parses the given expression to a PrimExpr.
+ + Parameters + ---------- + expr : ast.Expr + The input expression + bind_free_vars : bool + Whether or not the expression can bind previously free variables + + Returns + ------- + tir.PrimExpr + The result PrimExpr + """ + if isinstance(expr, ast.Var): + var_name = expr.id.name + if var_name in self.scope: + var = self.scope[var_name] + if not isinstance(var, tir.Var): + # TODO(@altanh): we may wish to relax this in the future to support constructing + # shapes from Dim-typed Relax expressions + self.report_error( + "non-dimension variables cannot appear in dimension expressions", + expr.span, + ) + return var + elif bind_free_vars: + # introduce TIR variable to scope, e.g. for func params or relax.call_packed + var = tir.Var(var_name, "int64", self.to_tvm_span(expr.span)) + self.scope[var_name] = var + return var + else: + self.report_error( + "cannot introduce new dimension variables in this expression", + expr.span, + ) + + elif isinstance(expr, ast.Constant): + if not isinstance(expr.value, int): + self.report_error("only integer constants are supported", expr.span) + return tir.const(expr.value, "int64", self.to_tvm_span(expr.span)) + + elif isinstance(expr, ast.Call): + if not isinstance(expr.func_name, ast.Op): + self.report_error( + "only built-in operators can be used in dimension expressions", + expr.func_name.span, + ) + op = PRIMEXPR_ARITHMETIC_OP_MAP[self.transform_expr(expr.func_name)] + # TODO(@altanh): it might not make sense to bind free variables + args = [self.parse_primexpr(arg, bind_free_vars) for arg in expr.params] + return op(*args, span=self.to_tvm_span(expr.span)) + + else: + self.report_error(f"unsupported dimension expression: {expr}", expr.span) + + def transform_module(self, mod: ast.Module) -> IRModule: + """Transforms the given synr Module to a Relax IRModule or Function. 
+
+        Parameters
+        ----------
+        mod : ast.Module
+            The input synr Module
+
+        Returns
+        -------
+        Union[IRModule, Function]
+            The parsed Relax IRModule or Function
+        """
+        if len(mod.funcs) != 1:
+            self.report_error(
+                "the input must be either a single function or a single class", mod.span
+            )
+
+        (root_func,) = mod.funcs.values()
+
+        if isinstance(root_func, ast.Function):
+            return self.transform_function(root_func, is_global=True)
+        elif isinstance(root_func, ast.Class):
+            # add global vars to the root scope for resolving global function calls
+            for func_name in root_func.funcs:
+                self.scope[func_name] = relay.GlobalVar(func_name)
+            for func_name, func in root_func.funcs.items():
+                global_var = self.scope[func_name]
+                self.mod[global_var] = self.transform_function(func, is_global=True)
+
+            # TODO(@yuchen): temporarily make the parser.from_source api also run the
+            # ResolveGlobals pass to populate shape and checked type, consistent with
+            # the behavior of directly parsing TVMScript
+            self.mod = relax.transform.Normalize()(self.mod)
+            self.mod = relax.transform.ResolveGlobals()(self.mod)
+            return self.mod
+        else:
+            self.report_error(f"unsupported input class: {root_func}", root_func.span)
+
+    def _parse_attrs_to_str(self, expr: ast.Attr) -> str:
+        strs = []
+        attr = expr
+        while isinstance(attr, ast.Attr):
+            strs.append(attr.field.name)
+            attr = attr.object
+        if not isinstance(attr, ast.Var):
+            self.report_error("unsupported attribute access", expr.span)
+        if attr.id.name in self.tir_prefix:
+            strs.append("tir")
+        elif attr.id.name in self.relax_prefix:
+            strs.append("relax")
+        else:
+            strs.append(attr.id.name)
+        result = ".".join(reversed(strs))
+        return result
+
+    def _get_lhs(self, stmt: ast.Assign) -> ast.Var:
+        if len(stmt.lhs) > 1:
+            self.report_error("currently we only support single variable assignments", stmt.span)
+        return stmt.lhs[0]
+
+    def _tir_from_synr(self, synr_ast: ast.Node) -> tir.PrimFunc:
+        """Parses the given synr AST using the TVMScript parser to a PrimFunc.
+
+        Parameters
+        ----------
+        synr_ast : ast.Node
+            The synr AST to be parsed.
+
+        Returns
+        -------
+        tir.PrimFunc
+            The parsed TIR PrimFunc.
+        """
+        # this behavior is assumed by the TIR parser
+        self._diagnostic_context._render_on_error = True
+        parser = _TIRScriptParser(synr_ast.span.start_line, self.tir_prefix, {})
+        prim_func = parser.do_transform(synr_ast, self._diagnostic_context)
+        self._diagnostic_context._render_on_error = False
+        return prim_func
+
+    def transform_function(self, func: ast.Function, is_global: bool = False) -> relax.Function:
+        """Transforms the given synr Function to a Relax Function.
+
+        Parameters
+        ----------
+        func : ast.Function
+            The input synr Function
+        is_global : bool, optional
+            Whether or not the input function is global/module-level, by default False
+
+        Returns
+        -------
+        relax.Function
+            The parsed Relax Function
+        """
+        if len(func.decorators) != 1:
+            self.report_error(
+                "functions must be decorated as a Relax Function or TIR PrimFunc", func.span
+            )
+        decorator_name = None
+        if isinstance(func.decorators[0], ast.Call):
+            decorator_name = self._parse_attrs_to_str(func.decorators[0].func_name)
+        else:
+            decorator_name = self._parse_attrs_to_str(func.decorators[0])
+
+        if decorator_name == "tir.prim_func":
+            return self._tir_from_synr(func)
+
+        if decorator_name != "relax.function":
+            self.report_error(
+                "functions must be decorated as a Relax Function or TIR PrimFunc", func.span
+            )
+
+        with self.new_scope():
+            params = []
+            for param in func.params:
+                ty, shape = self.transform_type(param.ty, bind_free_vars=True)
+                param = self.decl_var(param.name, ty, shape, param.span)
+                params.append(param)
+            new_body = self.transform_block(func.body)
+            ret_type, _ = self.transform_type(func.ret_type, bind_free_vars=False)
+
+        relax_func = relax.Function.create_unchecked(
+            params,
+            new_body,
+            ret_type,
+            # TODO: Parse ret shape
+            relax.RuntimeDepShape(),
+            attrs=None,
+            span=self.to_tvm_span(func.span),
+        )
+        if is_global:
+            relax_func = relax_func.with_attr("global_symbol", func.name)
+
+        return relax_func
+
+    def is_match_shape(self, stmt: ast.Stmt) -> bool:
+        """Returns whether or not the given statement is a MatchShape binding.
+
+        Parameters
+        ----------
+        stmt : ast.Stmt
+            The statement to be parsed.
+
+        Returns
+        -------
+        bool
+            Whether or not the statement is a MatchShape binding.
+        """
+        call_op = None
+        if isinstance(stmt, ast.UnassignedCall):
+            call_op = self.transform_expr(stmt.call.func_name)
+        elif isinstance(stmt, ast.Assign) and isinstance(stmt.rhs, ast.Call):
+            call_op = self.transform_expr(stmt.rhs)
+        return call_op == SpecialOp.MATCH_SHAPE
+
+    def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> relax.Binding:
+        """Parses the input synr statement to the corresponding Relax binding.
+
+        Parameters
+        ----------
+        stmt : ast.Stmt
+            The input synr statement (either an assignment or an unassigned call)
+        is_dataflow : bool, optional
+            Whether or not the binding is in a dataflow block, by default False
+
+        Returns
+        -------
+        relax.Binding
+            The parsed Relax binding
+        """
+        assert isinstance(stmt, (ast.Assign, ast.UnassignedCall))
+        if self.is_match_shape(stmt):
+            return self.parse_shape_binding(stmt, is_dataflow=is_dataflow)
+        else:
+            assert isinstance(stmt, ast.Assign)
+            return self.parse_var_binding(stmt, is_dataflow=is_dataflow)
+
+    def parse_shape_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> relax.MatchShape:
+        """Parses the input synr statement to a Relax shape binding.
+
+        Parameters
+        ----------
+        stmt : ast.Stmt
+            The input synr statement
+        is_dataflow : bool, optional
+            Whether or not the bound variable (if any) is a dataflow variable, by default False
+
+        Returns
+        -------
+        relax.MatchShape
+            The parsed Relax shape binding
+        """
+        var: ast.Var = None
+        call: ast.Call = None
+
+        if isinstance(stmt, ast.UnassignedCall):
+            # case where only dimension variables are bound, e.g.
+            # `match_shape(x.shape, (n, m))`
+            call = stmt.call
+        else:
+            # case where the statement also binds a Relax variable to the value being matched
+            assert isinstance(stmt, ast.Assign)
+            var = self._get_lhs(stmt)
+            call = stmt.rhs
+            if not isinstance(var, ast.Var):
+                self.report_error("the left-hand side of a binding must be a variable", stmt.span)
+
+        op = self.transform_expr(call.func_name)
+
+        assert op == SpecialOp.MATCH_SHAPE
+        if len(call.params) != 2:
+            self.report_error(op.value + " takes exactly two arguments", call.span)
+
+        value, pattern = call.params
+
+        value = self.transform_expr(value)
+        if not isinstance(pattern, ast.Tuple):
+            self.report_error(f"the pattern of a {op.value} call must be a tuple", pattern.span)
+        pattern = self.parse_shape(pattern, bind_free_vars=True)
+
+        if var is not None:
+            # TODO(@altanh): keep or discard annotation?
+            ty, shape = self.transform_type(stmt.ty, bind_free_vars=False)
+            var = self.decl_var(var.id.name, ty, shape, var.span, is_dataflow=is_dataflow)
+
+        return relax.MatchShape(value, pattern, var, self.to_tvm_span(stmt.span))
+
+    def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> relax.VarBinding:
+        """Parses the input synr assignment to a Relax variable binding.
+
+        Parameters
+        ----------
+        stmt : ast.Assign
+            The input synr assignment
+        is_dataflow : bool, optional
+            Whether or not the bound variable is a dataflow variable, by default False
+
+        Returns
+        -------
+        relax.VarBinding
+            The parsed Relax variable binding
+        """
+        var = self._get_lhs(stmt)
+        if isinstance(stmt.rhs, ast.Constant):
+            rhs = relax.const(stmt.rhs.value)
+        else:
+            rhs = self.transform_expr(stmt.rhs)
+        # an ExternFunc call comes from call_packed
+        bind_free_vars = isinstance(rhs, relay.Call) and isinstance(rhs.op, relax.ExternFunc)
+        ty, shape = self.transform_type(stmt.ty, bind_free_vars)
+        lhs = self.decl_var(var.id.name, ty, shape, var.span, is_dataflow=is_dataflow)
+        return relax.VarBinding(lhs, rhs, self.to_tvm_span(stmt.span))
+
+    # Stmts:
+    # - Assert: probably unsupported for now
+    # - Assign: VarBinding
+    # - For: ??
+    # - If: IfThenElse, must check no empty false branch
+    # - Return: just the returned expression, must terminate blocks? (special case if-else)
+    # - UnassignedCall: match_shape
+    # - With: relax.dataflow
+    def transform_stmt(
+        self, stmt: ast.Stmt
+    ) -> Union[relax.Expr, relax.Binding, relax.DataflowBlock]:
+        """Transforms the given synr statement to the corresponding Relax node.
+
+        Parameters
+        ----------
+        stmt : ast.Stmt
+            The input synr statement
+
+        Returns
+        -------
+        Union[relax.Expr, relax.Binding, relax.DataflowBlock]
+            The parsed Relax node
+        """
+        if isinstance(stmt, ast.Assign):
+            # dataflow bindings are handled separately in parse_dataflow
+            return self.parse_binding(stmt)
+        elif isinstance(stmt, ast.If):
+            # check branches are non-empty
+            if len(stmt.true.stmts) == 0 or len(stmt.false.stmts) == 0:
+                self.report_error("both branches of an if-else block must be non-empty", stmt.span)
+            true_assign = stmt.true.stmts[-1]
+            false_assign = stmt.false.stmts[-1]
+
+            # check last statement in each branch lines up
+            if not isinstance(true_assign, ast.Assign) or not isinstance(
+                self._get_lhs(true_assign), ast.Var
+            ):
+                self.report_error(
+                    "each branch of an if-else statement must end in a variable assignment",
+                    true_assign.span,
+                )
+            if not isinstance(false_assign, ast.Assign) or not isinstance(
+                self._get_lhs(false_assign), ast.Var
+            ):
+                self.report_error(
+                    "each branch of an if-else statement must end in a variable assignment",
+                    false_assign.span,
+                )
+            union_span = ast.Span.union([true_assign.span, false_assign.span])
+            if self._get_lhs(true_assign).id.name != self._get_lhs(false_assign).id.name:
+                self.report_error(
+                    "the final assignment of both branches must have the same variable",
+                    union_span,
+                )
+
+            var_name = self._get_lhs(true_assign).id.name
+
+            # rewrite branches to have a return statement so the blocks properly parse to SeqExprs
+            true_block = synr.ast.Block(
+                span=stmt.true.span,
+                stmts=stmt.true.stmts[:-1]
+                + [synr.ast.Return(span=true_assign.span, value=true_assign.rhs)],
+            )
+            false_block = synr.ast.Block(
+                span=stmt.false.span,
+                stmts=stmt.false.stmts[:-1]
+                + [synr.ast.Return(span=false_assign.span, value=false_assign.rhs)],
+            )
+
+            # parse the branches, build the final expression and binding
+            cond = self.transform_expr(stmt.condition)
+            with self.new_scope():
+                true_branch = self.transform_block(true_block)
+            with self.new_scope():
+                false_branch = self.transform_block(false_block)
+            # TODO(@altanh): the spans here are all sorts of messed up, not sure how to fix
+            ite_expr = relay.If(cond, true_branch, false_branch, self.to_tvm_span(stmt.span))
+            # TODO(@altanh): type and shape of return var
+            var = self.decl_var(var_name, None, None, union_span)
+            return relax.VarBinding(var, ite_expr, self.to_tvm_span(union_span))
+
+        elif isinstance(stmt, ast.Return):
+            return self.transform_expr(stmt.value)
+
+        elif isinstance(stmt, ast.UnassignedCall):
+            if self.transform_expr(stmt.call.func_name) == SpecialOp.MATCH_SHAPE:
+                return self.parse_shape_binding(stmt)
+            else:
+                self.report_error("the results of normal function calls must be bound", stmt.span)
+
+        elif isinstance(stmt, ast.With):
+            if not isinstance(stmt.rhs, ast.Call):
+                self.report_error("unsupported with block", stmt.span)
+
+            call = stmt.rhs
+            op = self.transform_expr(call.func_name)
+
+            # TODO(@altanh): perhaps this ought to be more general
+            if op == SpecialOp.DATAFLOW:
+                if len(call.params) > 0:
+                    self.report_error(
+                        "dataflow block constructor takes no arguments",
+                        call.params[0].span,
+                    )
+                if len(stmt.lhs) > 0:
+                    self.report_error(
+                        "dataflow blocks don't bind any patterns",
+                        stmt.lhs[0].span,
+                    )
+                return self.parse_dataflow(stmt.body)
+            else:
+                self.report_error("unsupported with block type", call.span)
+
+        elif isinstance(stmt, ast.Function):
+            func = self.transform_function(stmt)
+            return func
+
+        else:
+            self.report_error(
+                "unsupported statement",
+                stmt.span,
+            )
+
+    def parse_dataflow(self, block: ast.Block) -> relax.DataflowBlock:
+        """Parses the input synr block to a Relax dataflow block.
+
+        Parameters
+        ----------
+        block : ast.Block
+            The input synr block
+
+        Returns
+        -------
+        relax.DataflowBlock
+            The parsed Relax dataflow block
+        """
+        assert len(block.stmts) > 0, "should never have an empty dataflow block"
+        bindings = []
+
+        with self.new_scope():
+            # parse the output statement first to figure out which bindings assign normal Vars
+            output_stmt = block.stmts[-1]
+            output_var_names = set()
+            unbound_output_vars = {}
+            output_vars = []
+
+            if (
+                isinstance(output_stmt, ast.UnassignedCall)
+                and self.transform_expr(output_stmt.call.func_name) == SpecialOp.DATAFLOW_OUTPUT
+            ):
+                for var in output_stmt.call.params:
+                    if not isinstance(var, ast.Var):
+                        self.report_error("dataflow block outputs must be variables", var.span)
+                    output_var_names.add(var.id.name)
+                    unbound_output_vars[var.id.name] = var
+            else:
+                self.report_error(
+                    f"dataflow blocks must end with a {SpecialOp.DATAFLOW_OUTPUT.value} statement",
+                    output_stmt.span,
+                )
+
+            # output variables are bound to normal (not dataflow) Vars
+            for binding_stmt in block.stmts[:-1]:
+                if not isinstance(binding_stmt, (ast.Assign, ast.UnassignedCall)):
+                    self.report_error(
+                        "only bindings are supported in dataflow blocks",
+                        binding_stmt.span,
+                    )
+                is_match_shape = self.is_match_shape(binding_stmt)
+                is_dataflow = (
+                    isinstance(binding_stmt, ast.Assign)
+                    and self._get_lhs(binding_stmt).id.name not in output_var_names
+                )
+                binding = self.parse_binding(binding_stmt, is_dataflow=is_dataflow)
+                bindings.append(binding)
+                if not is_dataflow:
+                    if is_match_shape:
+                        for var in binding.pattern:
+                            output_vars.append(var)
+                    if binding.var is not None:
+                        output_vars.append(binding.var)
+                        unbound_output_vars.pop(binding.var.name_hint)
+
+            # check that the output variables are all bound locally
+            for unbound_var in unbound_output_vars.values():
+                self._diagnostic_context.emit(
+                    "error",
+                    "dataflow output variables must be bound locally in the block",
+                    unbound_var.span,
+                )
+            # FIXME(@altanh): TVMDiagnosticCtx has hard-coded `emit` to always be an error and raise
+            # an exception on the first call
+            self._diagnostic_context.render()
+
+            # make output variables visible in parent scope
+            for v in output_vars:
+                # v could already be in scope if it was a previously bound dimension variable
+                v_name = v.name if isinstance(v, tir.Var) else v.name_hint
+                if v_name not in self.scope:
+                    self.scope[v_name] = v
+
+        return relax.DataflowBlock(bindings, self.to_tvm_span(block.span))
+
+    def parse_attr(self, expr: ast.Attr) -> relax.Expr:
+        """Parses the given synr Attr node to a Relax expression.
+
+        Parameters
+        ----------
+        expr : ast.Attr
+            The synr Attr node to be parsed.
+
+        Returns
+        -------
+        relax.Expr
+            The parsed expression.
+        """
+        if expr.field.name == "shape":
+            obj = self.transform_expr(expr.object)
+            return relay.Call(
+                relay.op.get("relax.shape_of"), [obj], span=self.to_tvm_span(expr.span)
+            )
+        else:
+            # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_tir)
+            op_name = self._parse_attrs_to_str(expr)
+            # NOTE: at least for now, all special operators are namespaced
+            try:
+                return SpecialOp(op_name)
+            except ValueError:
+                # TODO(@altanh): maybe diagnostics here in case this fails?
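+                # e.g. `nn.softmax` is not a SpecialOp, so it falls through to
+                # the global operator registry lookup below (illustrative).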
+                return relay.op.get(op_name)
+
+    def parse_array_literal(
+        self, expr: ast.ArrayLiteral
+    ) -> Union[relax.const, relax.expr.Constant]:
+        """Parses the given synr ArrayLiteral node to a Relax constant.
+
+        Parameters
+        ----------
+        expr : ast.ArrayLiteral
+            The synr ArrayLiteral to be parsed.
+
+        Returns
+        -------
+        Union[relax.const, relax.expr.Constant]
+            The parsed Relax expression.
+        """
+
+        def _get_values(expr: ast.ArrayLiteral, vals: List[Any]) -> List[Any]:
+            # todo(@yongwww): the generic parsing util for ArrayLiteral should be in synr
+            if isinstance(expr, ast.Constant):
+                vals.append(expr.value)
+            elif isinstance(expr, ast.ArrayLiteral):
+                for elem in expr.values:
+                    # recursive call to get the nested list
+                    nested_vals = _get_values(elem, [])
+                    # avoid nested list for every element
+                    if len(nested_vals) == 1 and not isinstance(nested_vals[0], list):
+                        vals.append(nested_vals[0])
+                    else:
+                        vals.append(nested_vals)
+            else:
+                self.report_error(f"unsupported ast expression {expr}", expr.span)
+            return vals
+
+        const_values = _get_values(expr, [])
+        return relax.const(const_values)
+
+    # TODO(@tvm-team): Currently synr is over-specialized; unify with transform_type
+    # to parse types in the future
+    def parse_type_from_value(self, val: ast.Expr) -> relax.Type:
+        """Parses the type_args value of a call to a Relax type.
+
+        Parameters
+        ----------
+        val : ast.Expr
+            The type_args value to be parsed.
+
+        Returns
+        -------
+        relax.Type
+            The parsed Relax type.
+        """
+        if isinstance(val, ast.Var):
+            if val.id.name == "Tensor":
+                return relax.DynTensorType(ndim=-1, dtype=None, span=self.to_tvm_span(val.span))
+            elif val.id.name == "Object":
+                return relax.ObjectType(self.to_tvm_span(val.span))
+            elif val.id.name == "Shape":
+                return relax.ShapeType(self.to_tvm_span(val.span))
+            elif val.id.name == "Void":
+                return relay.TupleType(None, self.to_tvm_span(val.span))
+            else:
+                self.report_error(
+                    "type_args value must be Tensor, Object, Shape, Void, or Tuple()", val.span
+                )
+        elif isinstance(val, ast.Call):
+            if val.func_name.id.name == "Tensor":
+                ndim = -1
+                dtype = None
+                for k, v in val.keyword_params.items():
+                    if k.value == "ndim":
+                        ndim = v.value
+                    if k.value == "dtype":
+                        dtype = v.value
+                return relax.DynTensorType(ndim, dtype, self.to_tvm_span(val.span))
+            elif val.func_name.id.name == "Tuple":
+                field_types = []
+                for field in val.params:
+                    fty = self.parse_type_from_value(field)
+                    field_types.append(fty)
+                return relax.TupleType(field_types, self.to_tvm_span(val.span))
+            else:
+                self.report_error(
+                    "type_args elements must be Tensor or Tuple when having arguments, "
+                    f"but got {val.func_name.id.name}",
+                    val.span,
+                )
+        else:
+            self.report_error(
+                f"cannot parse {val} as the type_args value",
+                val.span,
+            )
+
+    def parse_call_attr(self, expr: ast.Call) -> Tuple[tvm.ir.Attrs, List[relax.Type]]:
+        """Parses keyword parameters as call attributes.
+
+        Parameters
+        ----------
+        expr : ast.Call
+            The synr Call to be parsed.
+
+        Returns
+        -------
+        Tuple[tvm.ir.Attrs, List[relax.Type]]
+            The parsed call attributes and type_args.
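+
+        Examples
+        --------
+        Illustrative only: for a call written as
+        ``relax.call_packed("my_func", x, type_args=(Tensor(ndim=2, dtype="float32")))``,
+        the ``type_args`` keyword parses to ``[relax.DynTensorType(ndim=2, dtype="float32")]``
+        and any remaining keywords are packed into the call attributes.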
+        """
+        op = self.transform_expr(expr.func_name)
+        kwargs = {}
+        type_args = None
+        for key, val in expr.keyword_params.items():
+            if key.value == "type_args":
+                type_args = self.parse_type_from_value(val)
+                if type_args:
+                    type_args = [type_args]
+            else:
+                assert isinstance(key, ast.Constant) and isinstance(key.value, str)
+                # TODO(@altanh): might need separate attribute parsing eventually
+                kwargs[key.value] = self.transform_expr(val)
+
+        is_default = False
+        if "attrs_type_key" in kwargs:
+            attrs_type_key = kwargs["attrs_type_key"]
+            kwargs.pop("attrs_type_key")
+        elif isinstance(op, tvm.ir.Op) and op.attrs_type_key != "":
+            attrs_type_key = op.attrs_type_key
+        else:
+            attrs_type_key = "DictAttrs"
+            is_default = True
+
+        attrs = None
+        if kwargs or not is_default:
+            attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
+        return (attrs, type_args)
+
+    def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
+        """Parses the given synr Call node to a Relax expression or PrimExpr.
+
+        Parameters
+        ----------
+        expr : ast.Call
+            The synr Call node to be parsed.
+
+        Returns
+        -------
+        Union[tir.PrimExpr, relax.Expr]
+            The parsed expression. It will be a PrimExpr if expr is an arithmetic operation on
+            PrimExprs.
+        """
+        if isinstance(expr.func_name, ast.Op) and expr.func_name.name == ast.BuiltinOp.Subscript:
+            if (
+                hasattr(expr.params[0], "params")
+                and hasattr(expr.params[0].params[0], "id")
+                and expr.params[0].params[0].id.name == "meta"
+            ):
+                # Get the index of the constant in the b64ndarrays metadata
+                const_idx = 0
+                if hasattr(expr.params[-1], "values"):
+                    const_idx = expr.params[-1].values[0].value
+
+                if self.mod.get_attrs():
+                    metadata = self.mod.get_attrs()
+                else:
+                    metadata = RelaxTransformer.get_meta()
+
+                if not metadata:
+                    self.report_error(
+                        "metadata is not found, please pass it to ir_module", expr.span
+                    )
+
+                attr_json = json.loads(str(metadata))
+                new_root = const_num = 0
+                for i, node in enumerate(attr_json["nodes"]):
+                    if "type_key" in node and "Constant" in node["type_key"]:
+                        if const_num == const_idx:
+                            new_root = i
+                            break
+                        const_num += 1
+                attr_json["root"] = new_root
+                return tvm.ir.load_json(json.dumps(attr_json))
+            else:
+                return self.transform_Subscript(expr)
+
+        op = self.transform_expr(expr.func_name)
+        type_args = None
+
+        if op == SpecialOp.CALL_PACKED:
+            extern_func = expr.params[0]
+            if not (isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)):
+                self.report_error(
+                    "the first argument of " + op.value + " must be the extern function name",
+                    extern_func.span,
+                )
+            op = relax.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
+            args = [self.transform_expr(arg) for arg in expr.params[1:]]
+
+        elif op == SpecialOp.TUPLE:
+            args = [self.transform_expr(arg) for arg in expr.params[0].values]
+            return relax.Tuple(args)
+
+        elif op == SpecialOp.TUPLE_GET_ITEM:
+            assert len(expr.params) == 2, "TupleGetItem expects exactly two parameters."
+            args = [self.transform_expr(arg) for arg in expr.params]
+            # the index of TupleGetItem only accepts an int instead of a tir.expr.IntImm
+            return relax.TupleGetItem(args[0], args[1].value)
+
+        elif op in (SpecialOp.CONSTANT, SpecialOp.CONST):
+            # relax const/Constant
+            arg = expr.params[0]
+            if isinstance(arg, ast.Constant):
+                return relax.const(arg.value)
+            elif isinstance(arg, ast.ArrayLiteral):
+                return self.parse_array_literal(arg)
+            else:
+                self.report_error(f"unsupported ast for const: {arg}", expr.span)
+
+        elif op == SpecialOp.TIR_CAST:
+            if len(expr.params) != 2:
+                self.report_error(
+                    f"tir.cast expects 2 arguments, but got {len(expr.params)}", expr.span
+                )
+            args = [self.transform_expr(arg) for arg in expr.params]
+            return tir.Cast(args[0], args[1])
+
+        elif op == SpecialOp.TIR_MAX:
+            if len(expr.params) != 2:
+                self.report_error(
+                    f"tir.max expects 2 arguments, but got {len(expr.params)}", expr.span
+                )
+            args = [self.transform_expr(arg) for arg in expr.params]
+            return tir.Max(args[0], args[1])
+
+        elif isinstance(op, ArithmeticOp):
+            args = [self.transform_expr(arg) for arg in expr.params]
+            if all([isinstance(arg, tir.PrimExpr) for arg in args]):
+                return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span))
+            # otherwise it's just a normal Relax operator call
+            op = RELAX_ARITHMETIC_OP_MAP[op]
+
+        elif isinstance(op, tvm.ir.Op):
+            args = [self.transform_expr(arg) for arg in expr.params]
+            # check call arity eagerly
+            if op.name == "relax.call_tir":
+                # call_tir is a special case because the last argument is optional
+                if len(args) != op.num_inputs and len(args) != op.num_inputs - 1:
+                    self.report_error(
+                        f"{op.name} expects {op.num_inputs} or {op.num_inputs - 1} "
+                        f"arguments but got {len(args)}",
+                        expr.span,
+                    )
+
+                if len(expr.keyword_params) != 1:
+                    self.report_error(
+                        f"{op.name} expects exactly one keyword argument with dtype as the key, "
+                        f"but got {len(expr.keyword_params)} keyword arguments",
+                        expr.span,
+                    )
+
+                if isinstance(args[0], str):
+                    # extern function call case: rewrite identifier to an ExternFunc
+                    args[0] = relax.ExternFunc(args[0], self.to_tvm_span(expr.params[1].span))
+
+                for key, val in expr.keyword_params.items():
+                    assert isinstance(key, ast.Constant) and isinstance(key.value, str)
+                    if key.value == "dtype":
+                        val = self.transform_expr(val)
+                        # single output case
+                        if isinstance(val, str):
+                            if not isinstance(args[2], relax.ShapeExpr):
+                                self.report_error(
+                                    "the output_shape and output_dtype of call_tir do not match",
+                                    expr.span,
+                                )
+                            type_args = [relax.DynTensorType(ndim=len(args[2].values), dtype=val)]
+                        elif isinstance(val, Tuple):
+                            # multiple outputs case
+                            if not isinstance(args[2], Tuple) and len(args[2]) != len(val):
+                                self.report_error(
+                                    "the output_shape and output_dtype of call_tir do not match",
+                                    expr.span,
+                                )
+                            types = []
+                            for i in range(len(args[2])):
+                                types.append(
+                                    relax.DynTensorType(ndim=len(args[2][i].values), dtype=val[i])
+                                )
+                            type_args = [relax.TupleType(types)]
+                        else:
+                            self.report_error(
+                                "call_tir expects the output_dtype to be a string or a tuple",
+                                expr.span,
+                            )
+                    else:
+                        self.report_error(
+                            f"{op.name} expects one keyword argument with dtype as the key, "
+                            f"but got {key.value} as the key",
+                            expr.span,
+                        )
+
+            elif op.num_inputs != -1 and len(args) != op.num_inputs:
+                self.report_error(
+                    f"{op.name} expects {op.num_inputs} arguments but got {len(args)}", expr.span
+                )
+
+        elif isinstance(op, relay.Expr):
+            args = [self.transform_expr(arg) for arg in expr.params]
+
+        else:
+            self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
+
+        if isinstance(op, tvm.ir.Op) and op.name == "relax.call_tir":
+            attrs = None
+        else:
+            attrs, type_args = self.parse_call_attr(expr)
+
+        if isinstance(op, relax.ExternFunc) and type_args is None:
+            self.report_error("call_packed is required to have type_args", expr.span)
+
+        return relax.Call(
+            op, args, attrs=attrs, type_args=type_args, span=self.to_tvm_span(expr.span)
+        )
+
+    # Exprs:
+    # - ArrayLiteral
+    # - Attr: use for .shape, and intrinsic/special operator namespace
+    # - Call
+    # - Constant
+    # - DictLiteral: unsupported for now
+    # - Slice: unsupported for now, could desugar to slice op
+    # - Tuple
+    # - Var
+    def transform_expr(self, expr: ast.Expr) -> relax.Expr:
+        """Transforms the input synr expression to a Relax expression.
+
+        Parameters
+        ----------
+        expr : ast.Expr
+            The input synr expression
+
+        Returns
+        -------
+        relax.Expr
+            The corresponding Relax expression
+        """
+
+        if isinstance(expr, ast.Attr):
+            return self.parse_attr(expr)
+
+        elif isinstance(expr, ast.Call):
+            if hasattr(expr.func_name, "field") and expr.func_name.field.name == "match_shape":
+                return self.transform_expr(expr.func_name)
+            return self.parse_call(expr)
+
+        elif isinstance(expr, ast.Tuple):
+            fields = [self.transform_expr(field) for field in expr.values]
+
+            # Empty shape tuples should be treated as shape expressions.
+            if all([isinstance(f, str) for f in fields]) and len(fields) != 0:
+                return tuple(fields)
+
+            # TODO(@altanh): this check might be too weak; we really only accept integral PrimExprs
+            # (e.g. int constants, dim vars, and integer operations on these)
+
+            # coerce to ShapeExpr when fields are all PrimExprs
+            if all([isinstance(f, tir.PrimExpr) for f in fields]):
+                return relax.ShapeExpr(fields, span=self.to_tvm_span(expr.span))
+            return relay.Tuple(fields, span=self.to_tvm_span(expr.span))
+
+        elif isinstance(expr, ast.Var):
+            var_name = expr.id.name
+            if _is_registered(var_name, op_set=self._registered_ops):
+                return relay.op.get(var_name)
+            if var_name in self.scope:
+                return self.scope[var_name]
+            # NOTE: this is a "hack" to get around Python eagerly parsing class method decorators
+            # first (meaning we need to resolve them after the functions are parsed). These
+            # GlobalVars need to be resolved using string equality only.
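+            # e.g. a call to a sibling @relax.function method of the same class parses
+            # to a GlobalVar here and is resolved later by the ResolveGlobals pass
+            # (see transform_module); illustrative note.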
+            return relay.GlobalVar(var_name)
+
+        elif isinstance(expr, ast.Constant):
+            # FIXME(@altanh): use internal representation that doesn't have precision limits here
+            if isinstance(expr.value, int):
+                return tir.IntImm("int64", expr.value, self.to_tvm_span(expr.span))
+            elif isinstance(expr.value, float):
+                return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span))
+            elif isinstance(expr.value, str):
+                # FIXME(@altanh): using StringImm seems to cause problems, but this loses span
+                return expr.value
+            elif expr.value is None:
+                return None
+            else:
+                return relax.const(expr.value)
+
+        elif isinstance(expr, ast.ArrayLiteral):
+            return self.parse_array_literal(expr)
+
+        elif isinstance(expr, ast.Op):
+            # TODO(@altanh): might need to generalize from ArithmeticOp if we decide to support
+            # array slicing syntax
+            try:
+                return ArithmeticOp(expr.name)
+            except ValueError:
+                self.report_error(f"unsupported built-in operator: {expr.name}", expr.span)
+        else:
+            self.report_error(f"unsupported expression: {expr}", expr.span)
+
+    def transform_Subscript(self, expr):
+        """Array access visitor."""
+
+        symbol = self.transform(expr.params[0])
+        if symbol is None:
+            self.report_error(
+                f"Variable {expr.params[0].id.name} is not defined.", expr.params[0].span
+            )
+        indexes = [self.transform(x) for x in expr.params[1].values]
+        if isinstance(symbol, relax.expr.Var):
+            if len(indexes) > 1:
+                self.report_error(
+                    "Only a single index can be provided when indexing into a `var`.",
+                    expr.params[1].span,
+                )
+            index = indexes[0].value
+            if not isinstance(index, (tvm.tir.PrimExpr, int)):
+                self.report_error(
+                    "Var load index should be an int or PrimExpr, but it is a " + str(type(index)),
+                    expr.span,
+                )
+            return call_with_error_reporting(
+                self.report_error,
+                expr.span,
+                relax.TupleGetItem,
+                symbol,
+                index,
+            )
+        elif isinstance(symbol, tvm.tir.expr.Var):
+            if symbol.dtype == "handle":
+                self.report_error(
+                    "Cannot read directly from a handle, use `T.match_buffer` "
+                    "to create a buffer to read from.",
+                    expr.params[0].span,
+                )
+            if len(indexes) > 1:
+                self.report_error(
+                    "Only a single index can be provided when indexing into a `var`.",
+                    expr.params[1].span,
+                )
+            index = indexes[0]
+            if not isinstance(index, (tvm.tir.PrimExpr, int)):
+                self.report_error(
+                    "Var load index should be an int or PrimExpr, but it is a " + str(type(index)),
+                    expr.span,
+                )
+
+            return call_with_error_reporting(
+                self.report_error,
+                expr.span,
+                tvm.tir.Load,
+                "float32",
+                symbol,
+                index,
+                True,
+                span=tvm_span_from_synr(expr.span),
+            )
+        elif isinstance(symbol, tvm.tir.Buffer):
+            return BufferSlice(
+                symbol, indexes, self.report_error, span=tvm_span_from_synr(expr.span)
+            )
+        elif isinstance(symbol, tvm.container.Array):
+            if len(indexes) > 1:
+                self.report_error(
+                    "Array access should be one-dimension access, but the indices are "
+                    + str(indexes),
+                    expr.span,
+                )
+            index = indexes[0]
+            if not isinstance(index, (int, tvm.tir.expr.IntImm)):
+                self.report_error(
+                    "Array access index expected int or IntImm, but got " + str(type(index)),
+                    expr.span,
+                )
+            if int(index) >= len(symbol):
+                self.report_error(
+                    f"Array access out of bound, size: {len(symbol)}, got index {index}.",
+                    expr.span,
+                )
+            return symbol[int(index)]
+        else:
+            self.report_error(
+                f"Cannot subscript from a {type(symbol).__name__}.",
+                expr.params[0].span,
+            )
+
+    def transform_block(self, block: ast.Block) -> relax.SeqExpr:
+        """Transforms the given synr block to a Relax SeqExpr (sequence of Blocks with a final
+        expression).
+
+        Parameters
+        ----------
+        block : ast.Block
+            The input synr block
+
+        Returns
+        -------
+        relax.SeqExpr
+            The parsed SeqExpr
+        """
+        # a block of statements needs to be converted to a SeqExpr of binding blocks
+        blocks = []
+        current_block = []
+        for stmt in block.stmts[:-1]:
+            parsed_stmt = self.transform_stmt(stmt)
+            if isinstance(parsed_stmt, relax.DataflowBlock):
+                if current_block:
+                    # FIXME(@altanh): need to manually construct span start & end
+                    blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(stmt.span)))
+                    current_block = []
+                blocks.append(parsed_stmt)
+            elif isinstance(parsed_stmt, (relax.Function, tir.PrimFunc)):
+                func_var = self.decl_var(stmt.name, None, None, stmt.span)
+                current_block.append(
+                    relax.VarBinding(func_var, parsed_stmt, self.to_tvm_span(stmt.span))
+                )
+            else:
+                assert isinstance(
+                    parsed_stmt, relax.Binding
+                ), "Expected relax.Binding, but got " + str(type(parsed_stmt))
+                current_block.append(parsed_stmt)
+        if len(current_block) > 0:
+            blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(block.stmts[-1].span)))
+
+        ret_stmt = block.stmts[-1]
+        if not isinstance(ret_stmt, ast.Return):
+            self.report_error(
+                "block must end with a returned expression",
+                ret_stmt.span,
+            )
+        ret_expr = self.transform_stmt(ret_stmt)
+
+        # only a call node in the function body
+        if isinstance(ret_expr, relax.Call) and len(blocks) == 0:
+            return ret_expr
+
+        # return a defined inner function
+        if (
+            len(blocks) > 0
+            and isinstance(blocks[-1].bindings[-1].value, relax.Function)
+            and hasattr(ret_expr, "name_hint")
+            and ret_expr.name_hint == blocks[-1].bindings[-1].var.name_hint
+        ):
+            return blocks[-1].bindings[-1].value
+
+        return relax.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span))
+
+
+class RelaxDiagnosticContext(synr.DiagnosticContext):
+    """Relax diagnostic context"""
+
+    def __init__(self, ir_mod):
+        self.tvm_diag_ctx = diagnostics.DiagnosticContext(ir_mod, diagnostics.get_renderer())
+        self.str_to_source_name = {}
+        self._render_on_error = False
+
+    def to_tvm_span(self, ast_span: ast.Span) -> tvm.ir.Span:
+        return tvm.ir.Span(
+            self.str_to_source_name[ast_span.filename],
+            ast_span.start_line,
+            ast_span.end_line,
+            ast_span.start_column,
+            ast_span.end_column,
+        )
+
+    def add_source(self, name: str, source: str) -> None:
+        """Add a file with source code to the context. This will be called
+        before any call to :py:func:`emit` that contains a span in this
+        file.
+        """
+        src_name = self.tvm_diag_ctx.module.source_map.add(name, source)
+        self.str_to_source_name[name] = src_name
+
+    def emit(self, level: str, message: str, span: tvm.ir.Span) -> None:
+        """Called when an error has occurred."""
+        if isinstance(span, ast.Span):
+            span = self.to_tvm_span(span)
+
+        if level == "error":
+            level = diagnostics.DiagnosticLevel.ERROR
+        elif level == "bug":
+            level = diagnostics.DiagnosticLevel.BUG
+        elif level == "warning":
+            level = diagnostics.DiagnosticLevel.WARNING
+        else:
+            # default unknown levels to ERROR
+            level = diagnostics.DiagnosticLevel.ERROR
+
+        assert span, "Span must not be null"
+        assert isinstance(span, tvm.ir.Span), "Expected tvm.ir.Span, but got " + str(type(span))
+
+        diag = diagnostics.Diagnostic(level, span, message)
+
+        self.tvm_diag_ctx.emit(diag)
+        if self._render_on_error:
+            self.render()
+
+    def render(self) -> Optional[Any]:
+        """Render out all error messages. Can either return a value or raise
+        an exception.
+        """
+        self.tvm_diag_ctx.render()
+
+
+# def script(f) -> Union[relax.Function, Callable[[], tvm.IRModule]]:
+#     """Parses the decorated Relax function or module (in Relax IR) to a Relax AST.
+
+#     Parameters
+#     ----------
+#     f : Union[function, class]
+#         The function or class to be parsed, written in the Relax IR.
+
+#     Returns
+#     -------
+#     Union[relax.Function, IRModule]
+#         The parsed Relax function or IRModule factory (which returns the parsed IRModule when
+#         called).
+#     """
+#     diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
+#     ast = synr.to_ast(f, diag_ctx)
+#     mod = RelaxTransformer().do_transform(ast, diag_ctx)
+#     if isinstance(mod, tvm.IRModule):
+#         return lambda: mod
+#     return mod
+
+
+def from_source(
+    input_func: Union[str, Callable],
+    relax_prefix: Optional[List[str]] = None,
+    tir_prefix: Optional[List[str]] = None,
+) -> Union[relax.Function, IRModule]:
+    """Parse a function or string into a Relax Function or IRModule.
+
+    If possible, pass the TVM script in as a function so that line numbers and
+    filename will be accurate.
+
+    Parameters
+    ----------
+    input_func : Union[str, Callable]
+        The Python function or source string to be parsed.
+
+    relax_prefix : Optional[List[str]]
+        The relax prefix list. Only used for str input; defaults to ["R", "relax"].
+
+    tir_prefix : Optional[List[str]]
+        The tir prefix list. Only used for str input; defaults to ["T", "tir"].
+
+    Returns
+    -------
+    output : Union[Function, IRModule]
+        The parsed Relax Function or IRModule.
+    """
+    metadata = None
+    if isinstance(input_func, str) and "b64ndarrays" in input_func:
+        input_func, metadata = metadata_partitioner(input_func)
+
+    mod = IRModule(attrs=metadata)
+    if isinstance(input_func, str):
+        relax_prefix = ["R", "relax"] if relax_prefix is None else relax_prefix
+        tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
+        return synr.to_ast(
+            input_func, RelaxDiagnosticContext(mod), RelaxTransformer(mod, relax_prefix, tir_prefix)
+        )
+    elif inspect.isfunction(input_func):
+        env: Dict[str, Any] = input_func.__globals__
+        relax_prefix = [key for key in env.keys() if env[key] is relax_namespace]
+        tir_prefix = [key for key in env.keys() if env[key] is tir_namespace]
+        return synr.to_ast(
+            input_func, RelaxDiagnosticContext(mod), RelaxTransformer(mod, relax_prefix, tir_prefix)
+        )
+    else:
+        raise TypeError("Only function definitions are supported.")
+
+
+# def fromtext(source: str, source_name: str = "from_string")
+#     -> Union[relax.Function, tvm.IRModule]:
+#     """Parses the given input string (in the Relax text format) to a Relax AST.
+
+#     Parameters
+#     ----------
+#     source : str
+#         The input source string. It should be either a decorated Python class or function.
+#     source_name : str, optional
+#         A descriptive name for error reporting, by default "from_string".
+
+#     Returns
+#     -------
+#     Union[relax.Function, IRModule]
+#         The parsed Relax function or IRModule factory (which returns the parsed IRModule when
+#         called).
+#     """
+#     # TODO(@altanh): actually use source_name somewhere?
+#     diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
+#     ast = synr.to_ast(source, diag_ctx)
+#     mod = RelaxTransformer().do_transform(ast, diag_ctx)
+#     if isinstance(mod, tvm.IRModule):
+#         return lambda: mod
+#     return mod
+
+
+def pretty_print(node, show_meta_data=False):
+    """Prints the given Relax IR node in the Relax text format.
+
+    Parameters
+    ----------
+    node : Union[relax.Type, relax.Expr, relax.Binding, relax.BindingBlock]
+        The Relax IR node to print.
+
+    show_meta_data : bool
+        Whether to include the meta data section in the text
+        if there is meta data.
+    """
+    print(tvm.script._ffi_api.AsRelaxScript(node, show_meta_data))
+
+
+# TODO(@altanh): printer stuff should probably live elsewhere?
+def astext(node, show_meta_data=False) -> str:
+    """Returns the Relax text format representation of the given Relax IR node.
+
+    Parameters
+    ----------
+    node : Union[relax.Type, relax.Expr, relax.Binding, relax.BindingBlock]
+        The Relax IR node to print.
+
+    show_meta_data : bool
+        Whether to include the meta data section in the text
+        if there is meta data.
+
+    Returns
+    -------
+    relax_text : str
+        The text format representation of the given Relax IR node.
+        If show_meta_data is True, the meta data section will be printed at the
+        beginning of the return string.
+    """
+    return tvm.script._ffi_api.AsRelaxScript(node, show_meta_data)
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 95%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
index d7db182f9d..2f2b4bbc25 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser_v1/tir/__init__.py
@@ -18,7 +18,6 @@
 # Type system
 from .ty import void, boolean, handle, Ptr, Tuple, Buffer
-from .ty import bool  # pylint: disable=redefined-builtin
 
 from .prim_func import prim_func
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 94%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index bd9aa1fdad..878240aa42 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -17,12 +17,12 @@
 """TVM Script Parser Intrinsic Classes"""
 # pylint: disable=redefined-builtin, relative-beyond-top-level
 import builtins
-from typing import List, Any
+from typing import Any, List
 
 import tvm.tir
-from tvm.tir import FloatImm
+
+from ....target import codegen
 from ..registry import register
-from ...target import codegen
 from ..utils import get_param_list, tvm_span_from_synr
 
@@ -52,9 +52,6 @@ def bool(imm, span):
 # nest closures so we copy the name string
 def wrap(name):
     def f(imm, span):
-        if name.startswith("float"):
-            if imm in {"inf", "-inf", "nan"}:
-                return FloatImm(dtype=name, value=float(imm), span=span)
         return imm.astype(name, span)
 
     f.__name__ = name
@@ -89,11 +86,6 @@ def truncmod(x, y, span):
     return tvm.tir.truncmod(x, y, span)
 
 
-@register
-def truncdiv(x, y, span):
-    return tvm.tir.truncdiv(x, y, span)
-
-
 @register
 def ceildiv(x, y, span):
     return tvm.tir.ceildiv(x, y, span)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 100%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 100%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 94%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
index 1d2550eecd..41fa6a5fa2 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/parser_v1/tir/scope_handler.py
@@ -112,9 +112,9 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
             scope = tvm.runtime.convert(scope)
             return tvm.tir.Allocate(
-                self.buffer_var,
-                dtype,
-                extents,
+                self.buffer.data,
+                self.buffer.dtype,
+                self.buffer.shape,
                 condition,
                 self.body,
                 annotations=annotations,
@@ -122,7 +122,7 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
             )
 
         super().__init__(allocate, concise_scope=True, def_symbol=True)
-        self.buffer_var = None
+        self.buffer = None
 
     def enter_scope(
         self,
@@ -146,15 +146,20 @@ def enter_scope(
         else:
            raise Exception("Internal Bug")
 
-        def setup_buffer_var(
+        def setup_buffer(
             extents, dtype, scope, condition=True, annotations=None, span: Span = None
         ):
-            """Setup buffer var for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
-            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
+            """Setup buffer object for a given type."""
+            self.buffer = tvm.tir.decl_buffer(
+                shape=extents,
+                dtype=dtype,
+                name=name,
+                scope=scope,
+                span=span,
+            )
 
-        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer_var, node)
+        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
+        context.update_symbol(name, self.buffer, node)
 
 
 @register
@@ -171,7 +176,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
                 list_data.append(i.value)
             nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
             n = tvm.tir.AllocateConst(
-                self.buffer_var,
+                self.buffer.data,
                 dtype,
                 shape,
                 nd_data,
@@ -182,7 +187,7 @@ def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
             return n
 
         super().__init__(allocate_const, concise_scope=True, def_symbol=True)
-        self.buffer_var = None
+        self.buffer = None
 
     def enter_scope(
         self,
@@ -206,13 +211,17 @@ def enter_scope(
        else:
            raise Exception("Internal Bug")
 
-        def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None):
+        def setup_buffer(data, dtype, shape, annotations: dict = None, span: Span = None):
             """Setup buffer var for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
-            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
+            self.buffer = tvm.tir.decl_buffer(
+                shape=shape,
+                dtype=dtype,
+                name=name,
+                span=span,
+            )
 
-        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer_var, node)
+        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
+        context.update_symbol(name, self.buffer, node)
 
 
 @register
@@ -239,18 +248,7 @@ def decl_buffer(
             axis_separators=None,
             span=None,
         ):
-            decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
-            if data is None:
-                # when data is not specified, the buffer is implicitly allocated
-                return tvm.tir.Allocate(
-                    self.buffer.data,
-                    dtype,
-                    shape,
-                    tvm.runtime.convert(True),
-                    decl_buffer,
-                    span=span,
-                )
-            return decl_buffer
+            return tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
 
         super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
@@ -300,7 +298,6 @@ def setup_buffer(
                 offset_factor=offset_factor,
                 buffer_type=buffer_type,
                 axis_separators=axis_separators,
-                name=name,
                 span=span,
             )
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 95%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
index 7cbf474410..15502055b7 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/parser_v1/tir/special_stmt.py
@@ -121,8 +121,8 @@ class MatchBuffer(SpecialStmt):
     def __init__(self):
         def match_buffer(
             param,
-            shape=None,
-            dtype=None,
+            shape,
+            dtype="float32",
             data=None,
             strides=None,
             elem_offset=None,
@@ -146,64 +146,28 @@ def match_buffer(
                 offset_factor, "offset_factor", self.context.report_error, self.node.span
             )
             buffer_name: str = self.node.lhs[0].id.name
-
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                buffer_name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                axis_separators,
+                span=span,
+            )
             if isinstance(param, tvm.tir.Var):
-                if shape is None:
-                    self.context.report_error(
-                        "Shape must be specified when binding input param",
-                        self.node.rhs.span,
-                    )
-
-                if dtype is None:
-                    dtype = "float32"
-
-                buffer = tvm.tir.decl_buffer(
-                    shape,
-                    dtype,
-                    buffer_name,
-                    data,
-                    strides,
-                    elem_offset,
-                    scope,
-                    align,
-                    offset_factor,
-                    buffer_type,
-                    axis_separators,
-                    span=span,
-                )
                 if param not in self.context.func_params:
                     self.context.report_error(
                         "Can not bind non-input param to buffer", self.node.rhs.params[0].span
                     )
                 self.context.func_buffer_map[param] = buffer
-
             elif isinstance(param, BufferSlice):
                 buffer_region = param.as_buffer_region()
-
-                if shape is None:
-                    shape = [dim.extent for dim in buffer_region.region]
-
-                if dtype is None:
-                    dtype = buffer_region.buffer.dtype
-
-                if elem_offset is None and offset_factor == 0:
-                    offset_factor = 1
-
-                buffer = tvm.tir.decl_buffer(
-                    shape,
-                    dtype,
-                    buffer_name,
-                    data,
-                    strides,
-                    elem_offset,
-                    scope,
-                    align,
-                    offset_factor,
-                    buffer_type,
-                    axis_separators,
-                    span=span,
-                )
-
                 self.context.current_block_scope().match_buffers.append(
                     tvm.tir.MatchBufferRegion(buffer, buffer_region)
                 )
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 94%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
index b8323dd4a1..4548102a9e 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/parser_v1/tir/ty.py
@@ -206,17 +206,7 @@ def __getitem__(self, args):
         _name = _dtype + _size + _lanes
         globals()[_name] = ConcreteType(_name)
 
-
-# All other DataType annotations are represented with the same string
-# as is used by `tvm.runtime.DataType`.  This does redefine the Python
-# built-in bool, but only within the context of `tvm.script.tir.ty`
-# and `tvm.script.tir` modules.  The `T.boolean` alias is maintained
-# for backwards compatibility.
-
-bool = ConcreteType("bool")  # pylint: disable=redefined-builtin
-boolean = bool
-
-
+boolean = ConcreteType("bool")
 handle = ConcreteType("handle")
 void = VoidType()
 Ptr = GenericPtrType()
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 93%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
index c655a62237..a3d3892c9d 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/parser_v1/utils.py
@@ -67,6 +67,9 @@ def get_param_list(

 def tvm_span_from_synr(span: synr.ast.Span) -> Span:
     """Convert a synr span to a TVM span"""
+    assert isinstance(span, synr.ast.Span), "Expected span to be synr.ast.Span, but got " + str(
+        type(span)
+    )
     return Span(
         SourceName(span.filename),
         span.start_line,
@@ -78,6 +81,9 @@ def tvm_span_from_synr(span: synr.ast.Span) -> Span:

 def synr_span_from_tvm(span: Span) -> synr.ast.Span:
     """Convert a TVM span to a synr span"""
+    assert isinstance(span, Span), "Expected span to be tvm.ir.Span, but got " + str(
+        type(span)
+    )
     return synr.ast.Span(
         span.source_name.name,
         span.line,
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
deleted file mode 100644
index a64eed055a..0000000000
--- a/python/tvm/script/tir/__init__.pyi
+++ /dev/null
@@ -1,487 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Dict,
-    Iterable,
-    Optional,
-    Tuple,
-    Union,
-    Sequence,
-    List,
-    Mapping,
-    overload,
-)
-from numbers import Number
-import builtins
-
-from tvm.tir.function import PrimFunc
-from tvm.tir import Range
-from tvm.runtime import Object
-from tvm.target import Target
-from .node import BufferSlice
-
-"""
-redefine types
-"""
-
-class PrimExpr:
-    def __init__(self: PrimExpr) -> None: ...
-    @overload
-    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __mod__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __floordiv__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ... - def __index__(self: PrimExpr) -> int: ... # so range doesn't complain - -class Var(PrimExpr): ... -class IterVar(Var): ... - -class Buffer: - @overload - def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ... - @overload - def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ... - @overload - def __setitem__( - self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr - ) -> None: ... - @overload - def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ... - @property - def data(self: Buffer) -> Ptr: ... - -""" -Intrinsic -""" - -def min_value(dtype: str) -> PrimExpr: ... -def max_value(dtype: str) -> PrimExpr: ... -def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def abs(x: PrimExpr) -> PrimExpr: ... -def load( - dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None -) -> PrimExpr: ... -def cast(value: PrimExpr, dtype: str) -> PrimExpr: ... -def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ... -def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ... -def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ... -def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... -def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... -def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ... -def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... -def evaluate(value: PrimExpr) -> None: ... -def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... -def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ... -def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ... -def store( - var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True -) -> None: ... -def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ... -def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ... -def preflattened_buffer( - buf: Buffer, - shape: Sequence[PrimExpr], - dtype: str = "float32", - data: Optional[Ptr] = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", -) -> Buffer: ... - -""" -Intrinsics - tvm builtin -""" - -def tvm_thread_allreduce( - *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str -) -> PrimExpr: ... - -""" -Unary operator -Note that any intrinsics not registered in script.tir.intrin -should add "dtype" as an argument. This is different from their -definition but intentional. -""" - -def exp(x: PrimExpr, dtype: str) -> PrimExpr: ... -def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ... -def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ... -def erf(x: PrimExpr, dtype: str) -> PrimExpr: ... -def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log(x: PrimExpr, dtype: str) -> PrimExpr: ... 
-def log2(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log10(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ... -def tan(x: PrimExpr, dtype: str) -> PrimExpr: ... -def cos(x: PrimExpr, dtype: str) -> PrimExpr: ... -def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def acos(x: PrimExpr, dtype: str) -> PrimExpr: ... -def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sin(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def asin(x: PrimExpr, dtype: str) -> PrimExpr: ... -def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def atan(x: PrimExpr, dtype: str) -> PrimExpr: ... -def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... -def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... - -""" -special_stmt - Buffers -""" - -def match_buffer( - param: Union[Var, BufferSlice], - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... -def decl_buffer( - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... -def buffer_decl( - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... -def alloc_buffer( - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... - -""" -special_stmt - Reads/Writes -""" - -@overload -def reads(read_regions: List[BufferSlice]) -> None: ... -@overload -def reads(*read_regions: BufferSlice) -> None: ... -@overload -def writes(write_region: List[BufferSlice]) -> None: ... -@overload -def writes(*write_region: BufferSlice) -> None: ... -def block_attr(attrs: Mapping[str, Object]) -> None: ... - -""" -special_stmt - Axis -""" - -class axis: - @overload - @staticmethod - def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def spatial( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @overload - @staticmethod - def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def reduce( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... 
- @overload - @staticmethod - def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def scan( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @overload - @staticmethod - def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def opaque( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @staticmethod - def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... - -def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ... - -""" -special_stmt - Annotations -""" - -def buffer_var(dtype: str, storage_scope: str) -> Var: ... -def func_attr(attrs: Mapping[str, Union[Object, str, bool, int, float]]) -> None: ... -def prim_func(input_func: Callable) -> PrimFunc: ... - -""" -special_stmt - Threads and Bindings -""" - -def env_thread(env_name: str) -> IterVar: ... -def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... - -""" -Scope handler -""" - -class block(ContextManager): - def __init__(self, name_hint: str = "") -> None: ... - def __enter__(self) -> Sequence[IterVar]: ... - -class init(ContextManager): - def __init__(self) -> None: ... - -class let(ContextManager): - def __init__(self, var: Var, value: PrimExpr) -> None: ... - -def where(cond: PrimExpr) -> None: ... -def allocate( - extents: List[PrimExpr], - dtype: str, - scope: str, - condition: Union[PrimExpr, builtins.bool] = True, - annotations: Optional[Mapping[str, Object]] = None, -) -> Buffer: ... -def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... -def realize( - buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True -) -> None: ... -def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... -def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ... - -""" -Scope handler - Loops -""" - -@overload -def serial( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def serial( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def parallel( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def parallel( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def vectorized( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def vectorized( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def unroll( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def unroll( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... 
-@overload -def thread_binding( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - thread: str, - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def thread_binding( - end: Union[PrimExpr, int], - thread: str, - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def for_range( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def for_range( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ... - -""" -ty - redefine types -""" - -class boolean: ... - -class handle(Var): - @overload - def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ... - @overload - def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ... - @overload - def __setitem__( - self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer - ) -> None: ... - @overload - def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ... - @property - def data(self: handle) -> Ptr: ... - -class Ptr: ... - -def target(target_str: Union[str, Mapping[str, Object]]) -> Target: ... - -class var(Var): - def __init__(self: Var, dtype: str): ... - -class bool(PrimExpr): - def __init__(self: bool, imm: Union[PrimExpr, builtins.bool, builtins.int]): ... - -class int8(PrimExpr): - def __init__(self: int8, imm: Union[PrimExpr, int]): ... - -class int16(PrimExpr): - def __init__(self: int16, imm: Union[PrimExpr, int]): ... - -class int32(PrimExpr): - def __init__(self: int32, imm: Union[PrimExpr, int]): ... - -class int64(PrimExpr): - def __init__(self: int64, imm: Union[PrimExpr, int]): ... - -class uint8(PrimExpr): - def __init__(self: uint8, imm: Union[PrimExpr, int]): ... - -class uint16(PrimExpr): - def __init__(self: uint16, imm: Union[PrimExpr, int]): ... - -class uint32(PrimExpr): - def __init__(self: uint32, imm: Union[PrimExpr, int]): ... - -class uint64(PrimExpr): - def __init__(self: uint64, imm: Union[PrimExpr, int]): ... - -# use typing.Literal instead for python 3.8 or higher -import sys - -if sys.version_info >= (3, 8): - from typing import Literal - - SpecialFloatLiteral = Literal["inf", "-inf", "nan"] -else: - SpecialFloatLiteral = str - -class float8(PrimExpr): - def __init__(self: float8, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... - -class float16(PrimExpr): - def __init__(self: float16, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... - -class float32(PrimExpr): - def __init__(self: float32, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... - -class float64(PrimExpr): - def __init__(self: float64, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... 
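The te/operation.py hunk just below widens `te.create_prim_func` with an optional `tir_var_list` parameter and adds `create_prim_func_from_outputs`. A minimal usage sketch of the extended entry point (the tensor names, shapes, and the doubling computation are illustrative only; omitting `tir_var_list` keeps the pre-existing behavior):

    import tvm
    from tvm import te

    # Illustrative TE computation: elementwise doubling of a 128x128 tensor.
    A = te.placeholder((128, 128), name="A")
    B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")

    # Build a TensorIR PrimFunc from the TE tensors; no tir_var_list is passed,
    # which exercises the default (None) path added in the hunk below.
    func = te.create_prim_func([A, B])
    print(func.script())
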
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 8da78a599c..3e66d05cc3 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -19,7 +19,7 @@ # pylint: disable=invalid-name from numbers import Integral as _Integral -from typing import List +from typing import List, Union import tvm._ffi import tvm.arith._ffi_api @@ -560,13 +560,17 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): return tvm.tir.IterVar(dom, name, 2, thread_tag, span) -def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: +def create_prim_func( + ops: List[_tensor.Tensor], tir_var_list: List[tvm.tir.Var] = None +) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression Parameters ---------- ops : List[Tensor] The source expression. + tir_var_list: List[Var] + TIR variables to add as parameters to generated PrimFunc Example ------- @@ -612,4 +616,24 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: """ if not isinstance(ops, (list, tuple, Array)): ops = [ops] - return _ffi_api.CreatePrimFunc(ops) + return _ffi_api.CreatePrimFunc(ops, tir_var_list) + + +def create_prim_func_from_outputs( + outputs: Union[_tensor.Tensor, List[_tensor.Tensor]], +) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from output tensor(s) in TE + + Parameters + ---------- + outputs : Union[Tensor, List[Tensor]] + The source expression. + + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(outputs, (list, tuple, Array)): + outputs = [outputs] + return _ffi_api.CreatePrimFuncFromOutputs(outputs) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 4628ae3626..5bc8d4e0f7 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -199,6 +199,7 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: def show(self, style: Optional[str] = None) -> None: """ A sugar for print highlighted TVM script. + Parameters ---------- style : str, optional diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index 9450ade34e..94d211a7fb 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -70,6 +70,7 @@ def prim_func_pass( opt_level: int = None, name: Optional[str] = None, required: Optional[List[str]] = None, + traceable=False, ) -> Union[Callable, PrimFuncPass]: """Decorate a function pass. @@ -148,7 +149,7 @@ def transform(func, mod, ctx): def create_function_pass(pass_arg): """Internal function that creates a function pass""" fname = name if name else pass_arg.__name__ - info = PassInfo(opt_level, fname, required) + info = PassInfo(opt_level, fname, required, traceable) if inspect.isclass(pass_arg): return _wrap_class_function_pass(pass_arg, info) if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): diff --git a/src/ir/function.cc b/src/ir/function.cc index c0cda704c4..500d94d11c 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -28,6 +28,7 @@ // and are only used in minimum cases where they are clearly marked. 
// // Rationale: We calls into the type specific WithAttr function +#include #include #include @@ -43,6 +44,36 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") return WithAttr(Downcast(std::move(func)), key, value); } else if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + +TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") + .set_body_typed([](BaseFunc func, Map attr_map) -> BaseFunc { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } else if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } else if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + +TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") + .set_body_typed([](BaseFunc func, String key) -> BaseFunc { + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); + } else if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); } else { LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); return func; diff --git a/src/ir/module.cc b/src/ir/module.cc index 8d6de5a536..3de19cfa9d 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -75,22 +75,32 @@ IRModule::IRModule(tvm::Map functions, } bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { - if (functions.size() != other->functions.size()) return false; if (!equal(this->attrs, other->attrs)) return false; + + if (functions.size() != other->functions.size()) return false; + // Update GlobalVar remap + for (const auto& gv : this->GetGlobalVars()) { + if (!other->ContainGlobalVar(gv->name_hint)) return false; + if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false; + } for (const auto& kv : this->functions) { - if (!other->ContainGlobalVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; } + if (type_definitions.size() != other->type_definitions.size()) return false; + // Update GlobalTypeVar remap + for (const auto& gtv : this->GetGlobalTypeVars()) { + if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false; + if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false; + } for (const auto& kv : this->type_definitions) { - if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false; } return true; } void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { - using KV = std::pair; + using KV = std::pair>; // hash the functions. 
std::vector temp; @@ -100,21 +110,24 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); hash_reduce(static_cast(temp.size())); + // Defhash the GlobalVar/GlobalTypeVar + for (size_t i = 0; i < temp.size(); ++i) { + hash_reduce.DefHash(temp[i].second.first); + } // hash the content for (size_t i = 0; i < temp.size(); ++i) { - hash_reduce(temp[i].first); - hash_reduce(temp[i].second); + hash_reduce(temp[i].second.second); } }; for (const auto& kv : this->functions) { - temp.emplace_back(kv.first->name_hint, kv.second); + temp.emplace_back(kv.first->name_hint, std::make_pair(kv.first, kv.second)); } reduce_temp(); temp.clear(); for (const auto& kv : this->type_definitions) { - temp.emplace_back(kv.first->name_hint, kv.second); + temp.emplace_back(kv.first->name_hint, std::make_pair(kv.first, kv.second)); } reduce_temp(); hash_reduce(this->attrs); @@ -436,10 +449,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { IRModule mod = args[0]; @@ -527,10 +538,22 @@ TVM_REGISTER_GLOBAL("ir.Module_WithAttr") return WithAttr(mod, key, value); }); +TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr") + .set_body_typed([](IRModule mod, String key) -> IRModule { return WithoutAttr(mod, key); }); + +TVM_REGISTER_GLOBAL("ir.Module_WithAttrs") + .set_body_typed([](IRModule mod, Map attr_map) -> IRModule { + return WithAttrs(mod, attr_map); + }); + TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); +TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { + return mod->GetAttrs(); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 77ea942a0b..030453c60f 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -341,11 +342,13 @@ class ModulePass : public Pass { TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required, + bool traceable) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); + pass_info->traceable = std::move(traceable); data_ = std::move(pass_info); } @@ -391,6 +394,7 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c VLOG(1) << "Result module:" << std::endl << PrettyPrint(mod); + // TODO(@lesheng): will need to check if the updated IRModule is well formed return mod; } @@ -404,7 +408,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); - PassInfo pass_info = PassInfo(0, std::move(name), {}); + PassInfo pass_info = PassInfo(0, std::move(name), {}, /* traceable */ false); 
   n->pass_info = std::move(pass_info);
   data_ = std::move(n);
 }
@@ -446,26 +450,61 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c
       VLOG(0) << "skipping disabled pass '" << pass_info->name << "'";
       continue;
     }
+
     // resolve dependencies
     for (const auto& it : pass_info->required) {
       mod = GetPass(it)(std::move(mod), pass_ctx);
     }
-    mod = pass(std::move(mod), pass_ctx);
+
+    // This handles passes that do not use the Relax tuning API (untraceable passes).
+    // We make untraceable passes trackable when the pass context has a trace (trace mode).
+    // When users provide the set of passes to trace (make_traceable), only those passes are
+    // made trackable.
+    if (pass_ctx->trace_stack.size() && !pass_info->traceable &&
+        (!pass_ctx->make_traceable.defined() ||
+         pass_ctx->make_traceable.value().count(pass_info->name))) {
+      // TODO(tvm-team): Currently, there are some inconsistencies in the pass registration.
+      // 1. Some passes are not registered in the ffi registry.
+      // 2. Some passes do not follow the name convention. (e.g., <ffi key> = <namespace> + <pass name>)
+
+      // Due to these problems, serialization with non-traceable passes is handled in a hacky way
+      // for now. Find a systematic way to identify such inconsistencies and fix them.
+
+      // In the future, the ffi key for a pass should be deduced from its name.
+      String transform_func_key = "relax.tuning_api.Choice.default_transform_func";
+      String constr_func_key = "relax.tuning_api.Choice.default_constr_func";
+
+      relax::Knob knob = relax::Knob(
+          pass_info->name, {{"Applied", relax::Choice(transform_func_key, Array(),
+                                                      constr_func_key, Array())}});
+
+      // Add the new decision to the trace at the top of the stack.
+      auto trace = Downcast(pass_ctx->trace_stack.back());
+      trace->Add(knob, "Applied");
+      // In the future, we should just have
+      //   mod = trace->Add(knob, "enabled");
+      // instead of the two lines below.
+ mod = pass(std::move(mod), pass_ctx); + trace->SetOutMod(mod); + + } else { + mod = pass(std::move(mod), pass_ctx); + } } return mod; } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return ModulePass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") - .set_body_typed([](int opt_level, String name, tvm::Array required) { - return PassInfo(opt_level, name, required); + .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { + return PassInfo(opt_level, name, required, traceable); }); TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -516,7 +555,8 @@ TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValu int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - PassInfo pass_info = PassInfo(opt_level, name, required); + bool traceable = args[4]; + PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); }); @@ -539,7 +579,9 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, - Optional> config) { + Optional> config, Array trace_stack, + Optional> make_traceable, int num_evals, + Optional tuning_api_database) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; @@ -549,6 +591,10 @@ TVM_REGISTER_GLOBAL("transform.PassContext") if (config.defined()) { pctx->config = config.value(); } + pctx->trace_stack = std::move(trace_stack); + pctx->make_traceable = std::move(make_traceable); + pctx->num_evals = std::move(num_evals); + pctx->tuning_api_database = std::move(tuning_api_database); PassConfigManager::Global()->Legalize(&(pctx->config)); return pctx; }); @@ -564,7 +610,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\tdisabled passes: " << node->disabled_pass << "\n"; p->stream << "\tinstruments: " << node->instruments << "\n"; - p->stream << "\tconfig: " << node->config; + p->stream << "\tconfig: " << node->config << "\n"; + p->stream << "\ttrace stack: " << node->trace_stack; }); class PassContext::Internal { @@ -574,6 +621,22 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; +TVM_REGISTER_GLOBAL("transform.GetTraceStack") + .set_body_method(&PassContextNode::GetTraceStack); +TVM_REGISTER_GLOBAL("transform.PushTrace") + .set_body_method(&PassContextNode::PushTrace); +TVM_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); +TVM_REGISTER_GLOBAL("transform.GetTraceStackSize") + .set_body_method(&PassContextNode::GetTraceStackSize); +TVM_REGISTER_GLOBAL("transform.GetCurrentTrace") + .set_body_method(&PassContextNode::GetCurrentTrace); +TVM_REGISTER_GLOBAL("transform.SetNumEvals") + .set_body_method(&PassContextNode::SetNumEvals); +TVM_REGISTER_GLOBAL("transform.IncNumEvals") + .set_body_method(&PassContextNode::IncNumEvals); +TVM_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") + .set_body_method(&PassContextNode::GetTuningAPIDatabase); + TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); 
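The tuning-api hooks registered above are ordinary packed functions, so they can be reached from Python without any extra binding; a hypothetical sketch (it assumes this build ships the Relax tuning API, and it only inspects the default, empty trace stack):

    import tvm

    # Look up one of the globals registered in the hunk above.
    get_stack_size = tvm.get_global_func("transform.GetTraceStackSize")

    # The current PassContext starts with an empty trace stack, so Sequential
    # runs its passes without recording "Applied" decisions.
    ctx = tvm.transform.PassContext.current()
    print(get_stack_size(ctx))  # expected: 0
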
TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); @@ -592,7 +655,7 @@ Pass PrintIR(String header, bool show_meta_data) { LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); return mod; }; - return CreateModulePass(pass_func, 0, "PrintIR", {}); + return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); diff --git a/src/ir/type.cc b/src/ir/type.cc index fe8e00329b..86dda2a274 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -144,8 +144,8 @@ TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { - return TupleType(fields); +TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields, Span span) { + return TupleType(fields, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc new file mode 100644 index 0000000000..56e358d4b7 --- /dev/null +++ b/src/printer/relax_script_printer.cc @@ -0,0 +1,668 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file printer/relax_script_printer.cc + * \brief Printer class to print Relax IR to parsable Python + */ + +#include +#include +#include + +#include +#include + +#include "doc.h" +#include "text_printer.h" + +namespace tvm { +namespace relax { + +Doc PrintDType(DataType dtype) { return Doc::StrLiteral(runtime::DLDataType2String(dtype)); } + +Doc RelaxScriptPrinter::Print(const ObjectRef& node) { + if (node->IsInstance()) { + return PrintIRModule(Downcast(node)); + } else if (node->IsInstance()) { + return VisitType(Downcast(node)); + } else if (node->IsInstance()) { + return VisitExpr(Downcast(node)); + } else if (node->IsInstance()) { + return tir::AsTVMScriptDoc(Downcast(node), "T", false); + } else { + return VisitNode(node); + } +} + +Doc RelaxScriptPrinter::VisitNode_(const relay::TupleNode* op) { + size_t num_fields = op->fields.size(); + + if (num_fields == 0) { + return Doc::Text("tuple()"); + } + + Doc doc; + std::vector fields; + + for (const Expr& field : op->fields) { + fields.push_back(Print(field)); + } + doc << "(" << Doc::Concat(fields, Doc::Text(", ")); + if (num_fields == 1) { + doc << ","; + } + doc << ")"; + + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relay::GlobalVarNode* op) { + return Doc::Text(op->name_hint); +} + +Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) { + // TODO(@altanh): how to support when func cannot be printed as Python expr? + // e.g. 
Function or If + Doc doc; + std::vector args; + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (op->op == call_tir_op) { + doc << "R.call_tir"; + + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); + } + doc << "(" << Doc::Concat(args, Doc::Text(", ")); + + Type output_type = op->type_args[0]; + if (const auto* out_type = output_type.as()) { + doc << ", dtype=" << PrintDType(out_type->dtype) << ")"; + } else if (const auto* out_type = output_type.as()) { + std::vector dtypes; + for (auto field : out_type->fields) { + if (const auto* field_type = field.as()) { + Doc dtype; + dtype << PrintDType(field_type->dtype); + dtypes.push_back(dtype); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << field_type->GetTypeKey(); + } + } + doc << ", dtype=(" << Doc::Concat(dtypes, Doc::Text(", ")) << "))"; + } else { + LOG(FATAL) << "TypeError: Invalid type: " << output_type->GetTypeKey(); + } + return doc; + } + + if (op->op.as()) { + doc << "R.call_packed"; + args.push_back(Print(op->op)); + } else { + doc << Print(op->op); + } + + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); + } + doc << "(" << Doc::Concat(args, Doc::Text(", ")); + + std::vector attrs = PrintAttrs(op->attrs); + if (op->attrs.defined()) { + attrs.push_back(Doc::Text("attrs_type_key=") << Doc::StrLiteral(op->attrs->GetTypeKey())); + } + if (!attrs.empty()) { + doc << ", " << Doc::Concat(attrs); + } + + if (!op->type_args.empty()) { + doc << ", type_args=("; + std::vector type_args = PrintTypeArgs(op->type_args); + doc << Doc::Concat(type_args); + doc << ")"; + } + + doc << ")"; + + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const OpNode* op) { return Doc::Text(op->name); } + +Doc RelaxScriptPrinter::VisitNode_(const relay::TupleGetItemNode* op) { + Doc doc; + doc << Print(op->tuple) << "[" << op->index << "]"; + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::VarNode* op) { + if (!var_id_map_.count(op->vid)) { + var_id_map_[op->vid] = GetUniqueName(op->name_hint(), "v"); + } + + return var_id_map_[op->vid]; +} + +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ +template +Doc ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; + } + return Doc::Text(os.str()); +} + +// Overload of Expr printing functions +Doc RelaxScriptPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, + bool optional_info) { + Doc printed_expr; + if (meta) { + printed_expr = meta_->GetMetaNode(GetRef(expr.get())); + } else { + printed_expr = VisitNode(expr); + } + return printed_expr; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::ConstantNode* op) { + Doc doc; + // Print out simple scalars directly. 
+ if (op->data->ndim == 0) { + std::ostringstream os; + DataType dtype = DataType(op->data->dtype); + ICHECK_EQ(op->data->device.device_type, kDLCPU); + auto scalar_val = ScalarLiteral(dtype, 0); + if (dtype == DataType::Int(32)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Int(64)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Float(32)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Float(64)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Bool()) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } + return doc << scalar_val; + } + // default fall-back, record it as meta node. + // Don't append optional_info. Because the entry function is Print, + // and it will append the optional_info afterwards. + return doc << PrintExpr(GetRef(op), true, false, false); +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowVarNode* op) { + if (!var_id_map_.count(op->vid)) { + var_id_map_[op->vid] = GetUniqueName(op->name_hint(), "dv"); + } + + return var_id_map_[op->vid]; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) { + // TODO(@altanh): support more PrimExpr printing, and check that empty tuple + // is never ambiguously printed as "()" + Doc doc; + + std::vector fields; + for (const PrimExpr& field : op->values) { + fields.push_back(Print(field)); + } + doc << "(" << Doc::Concat(fields, Doc::Text(", ")); + if (fields.size() == 1) { + doc << ","; + } + doc << ")"; + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::RuntimeDepShapeNode* op) { + Doc doc; + + doc << "_"; + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) { + Doc doc; + if (op->var.defined()) { + doc << Print(op->var) << PrintVarAnnotation(op->var) << " = "; + } + doc << "R.match_shape("; + // TODO(@altanh): maybe op->pattern should just be a ShapeExpr? 
+ doc << Print(op->value) << ", " << Print(relax::ShapeExpr(op->pattern)); + doc << ")"; + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) { + // TODO(@altanh): think deeper about normal form (need to be strict about block exprs) + if (const relay::IfNode* ite = op->value.as()) { + return PrintIfStmt(op->var, GetRef(ite)); + } else if (const relax::FunctionNode* func = op->value.as()) { + return PrintFunctionDef(Print(op->var), GetRef(func)); + } else if (const tir::PrimFuncNode* prim_func = op->value.as()) { + return PrintPrimFunc(op->var->name_hint(), GetRef(prim_func)); + } else { + Doc doc; + bool print_annotation = true; + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (const CallNode* value = op->value.as()) { + if (value->op == call_tir_op) { + print_annotation = false; + } + } + doc << Print(op->var); + if (print_annotation) { + doc << PrintVarAnnotation(op->var); + } + doc << " = " << Print(op->value); + return doc; + } +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::BindingBlockNode* op) { + Doc doc; + for (const relax::Binding& binding : op->bindings) { + doc << Print(binding) << Doc::NewLine(); + } + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) { + Doc block; + Doc body; + std::vector return_vars; + for (const relax::Binding& binding : op->bindings) { + body << Print(binding) << Doc::NewLine(); + Var var; + if (const relax::VarBindingNode* var_binding = binding.as()) { + var = var_binding->var; + } else if (const relax::MatchShapeNode* shape_binding = binding.as()) { + var = shape_binding->var; + } + if (var.defined() && !var.as()) { + return_vars.push_back(Print(var)); + } + } + ICHECK(!return_vars.empty()) << "dataflow blocks should have at least one output variable"; + body << "R.output(" << Doc::Concat(return_vars, Doc::Text(", ")) << ")"; + block << "with R.dataflow():" << Doc::NewLine(4); + block << Doc::Indent(4, body) << Doc::NewLine(); + return block; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::SeqExprNode* op) { + Doc doc; + int i = 0; + for (const relax::BindingBlock& block : op->blocks) { + doc << "# block " << i++ << Doc::NewLine(); + doc << Print(block); + } + // NOTE: the body expression is printed in the parent, since SeqExprs are used for both Function + // bodies and If expr bodies (which don't have a "return" statement but instead a binding) + return doc; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::FunctionNode* op) { + Optional gsymbol = op->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()); + return PrintFunctionDef(Doc::Text(gsymbol.value()), GetRef(op)); +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::ExternFuncNode* op) { + return Doc::StrLiteral(op->global_symbol); +} + +Doc RelaxScriptPrinter::VisitExpr_(const tir::VarNode* op) { + tir::Var var = GetRef(op); + if (!dim_var_map_.count(var)) { + dim_var_map_[var] = GetUniqueName(var->name_hint, "dim"); + } + return dim_var_map_[var]; +} + +Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) { + return Doc::Text(std::to_string(op->value)); +} + +#define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \ + Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) { \ + Doc doc; \ + doc << "(" << Print(op->a) << OpString; \ + doc << Print(op->b) << ")"; \ + return doc; \ + } + +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::AddNode, " + ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::SubNode, " - ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::MulNode, 
" * ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::DivNode, " / ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::FloorDivNode, " // ") + +Doc RelaxScriptPrinter::VisitExpr_(const tir::CastNode* op) { + Doc doc; + doc << "T.cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; + return doc; +} + +Doc RelaxScriptPrinter::VisitExpr_(const tir::MaxNode* op) { + Doc doc; + doc << "T.max(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc RelaxScriptPrinter::VisitType_(const relax::ShapeTypeNode* node) { return Doc::Text("Shape"); } + +Doc RelaxScriptPrinter::VisitType_(const relax::ObjectTypeNode* node) { + return Doc::Text("Object"); +} + +Doc RelaxScriptPrinter::VisitType_(const relax::DynTensorTypeNode* node) { + // NOTE: to print shape information, use PrintTensorAnnotation + return PrintTensorAnnotation(GetRef(node), NullOpt); +} + +Doc RelaxScriptPrinter::VisitType_(const relay::TupleTypeNode* node) { + if (node->fields.empty()) { + return Doc::Text("Tuple()"); + } + + Doc doc; + + std::vector fields; + for (Type ty : node->fields) { + fields.push_back(Print(ty)); + } + doc << "Tuple(" << Doc::Concat(fields) << ")"; + + return doc; +} + +Doc RelaxScriptPrinter::VisitType_(const relay::FuncTypeNode* node) { + Doc doc; + doc << "Callable"; + if (node->type_params.size() != 0) { + doc << "("; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); + } + doc << Doc::Concat(type_params); + doc << ")"; + } + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); + } + // TODO(@yongwww): Change it to Callable[[Arg1Type, Arg2Type, ...,], ReturnType] + // to be consistent with Python type hint syntax, + return doc << "((" << Doc::Concat(arg_types) << "), " << Print(node->ret_type) << ")"; +} + +Doc RelaxScriptPrinter::PrintAttr(const ObjectRef& attr) { + if (attr.defined()) { + if (const StringObj* str = attr.as()) { + return Doc::StrLiteral(GetRef(str)); + } else { + return VisitAttr(attr); + } + } else { + return Doc::Text("None"); + } +} + +std::vector RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) { + std::vector kwargs; + if (!attrs.defined()) { + return kwargs; + } else if (const DictAttrsNode* dict_attrs = attrs.as()) { + for (const auto& k : dict_attrs->dict) { + kwargs.push_back(Doc::Text(k.first) << "=" << Print(k.second)); + } + } else { + AttrPrinter attr_printer(&kwargs, this); + const_cast(attrs.operator->())->VisitAttrs(&attr_printer); + } + return kwargs; +} + +std::vector RelaxScriptPrinter::PrintTypeArgs(const Array& type_args) { + std::vector type_args_doc; + if (!type_args.empty()) { + for (const auto& type : type_args) { + if (const auto* tensor = type.as()) { + Doc doc; + doc << "Tensor(ndim=" << tensor->ndim << ", dtype=" << PrintDType(tensor->dtype) << ")"; + type_args_doc.push_back(doc); + } else { + type_args_doc.push_back(this->VisitType(type)); + } + } + } + return type_args_doc; +} + +Doc RelaxScriptPrinter::VisitAttrDefault_(const Object* op) { + return PrintAttr(GetRef(op)); +} + +Doc RelaxScriptPrinter::VisitAttr_(const ArrayNode* op) { + Doc doc; + std::vector arr_vals; + for (ObjectRef val : *op) { + arr_vals.push_back(PrintAttr(val)); + } + doc << "[" << Doc::Concat(arr_vals) << "]"; + return doc; +} + +Doc RelaxScriptPrinter::VisitAttr_(const tir::IntImmNode* op) { + return Doc::Text(std::to_string(op->value)); +} + +Doc RelaxScriptPrinter::VisitAttr_(const tir::FloatImmNode* op) { + return 
Doc::Text(std::to_string(op->value)); +} + +Doc RelaxScriptPrinter::PrintIRModule(const IRModule& mod) { + Doc doc; + if (ShowMetaData()) { + doc << "@tvm.script.ir_module(metadata=metadata)" << Doc::NewLine(); + } else { + doc << "@tvm.script.ir_module" << Doc::NewLine(); + } + doc << "class Module:"; + for (const std::pair& pr : mod->functions) { + Doc func; + if (pr.second.as()) { + func = PrintPrimFunc(pr.first->name_hint, Downcast(pr.second)); + } else { + func = Print(pr.second); + } + doc << Doc::Indent(4, Doc::NewLine() << func); + } + return doc; +} + +Doc RelaxScriptPrinter::PrintPrimFunc(const String& name, const tir::PrimFunc& func) { + // we need the mod for TVMScriptPrinter to properly print the function name - maybe it's worth + // refactoring to avoid this? + IRModule mod; + mod->Add(relay::GlobalVar(name), func); + return tir::AsTVMScriptDoc(mod, "T", false, func); +} + +Doc RelaxScriptPrinter::PrintIfStmt(const relax::Var& var, const relay::If& ite) { + const relax::SeqExprNode* true_branch = ite->true_branch.as(); + const relax::SeqExprNode* false_branch = ite->false_branch.as(); + // TODO(@altanh): this invariant must be maintained by the normal form + ICHECK(true_branch && false_branch) + << "in the Relax IR normal form, each branch of a If expression should be a SeqExpr"; + + Doc doc; + doc << "if " << Print(ite->cond) << ":" << Doc::NewLine(4); + doc << Doc::Indent(4, Print(GetRef(true_branch))); + doc << Doc::Indent(4, Print(relax::VarBinding(var, true_branch->body))); + doc << Doc::NewLine(); + doc << "else:" << Doc::NewLine(4); + doc << Doc::Indent(4, Print(GetRef(false_branch))); + doc << Doc::Indent(4, Print(relax::VarBinding(var, false_branch->body))); + return doc; +} + +Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& func) { + Doc doc; + + std::vector params; + for (size_t i = 0; i < func->params.size(); ++i) { + relax::Var var = func->params[i]; + Doc param; + param << Print(var) << PrintVarAnnotation(var); + params.push_back(param); + } + if (ShowMetaData()) { + doc << "@R.function(metadata=metadata)" << Doc::NewLine(); + } else { + doc << "@R.function" << Doc::NewLine(); + } + doc << "def " << name << "(" << Doc::Concat(params, Doc::Text(", ")) << ")"; + if (func->ret_type.defined()) { + doc << " -> " << Print(func->ret_type); + } + doc << ":" << Doc::NewLine(4); + + if (const relax::SeqExprNode* body = func->body.as()) { + doc << Doc::Indent(4, Print(func->body)); + doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine(); + } else if (const relax::FunctionNode* body = func->body.as()) { + // nested function + String func_name; + Optional gsymbol = body->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined()) { + func_name = gsymbol.value(); + } else { + func_name = "local_func_" + std::to_string(local_func_counter_++); + } + Doc nested_func = PrintFunctionDef(Doc::Text(func_name), GetRef(body)); + doc << Doc::Indent(4, nested_func); + doc << Doc::Indent(4, Doc::Text("return ") << func_name) << Doc::NewLine(); + } else { + doc << Doc::Indent(4, Doc::Text("return ") << Print(func->body)) << Doc::NewLine(); + } + + return doc; +} + +Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) { + // TODO(@altanh): we should consider moving annotation into binding + Doc doc; + Type annotation = var->checked_type_; + if (annotation.defined()) { + doc << ": "; + if (const relax::DynTensorTypeNode* tty = annotation.as()) { + doc << PrintTensorAnnotation(GetRef(tty), var->shape_); + } else if 
(const TupleTypeNode* tty = annotation.as()) { + doc << PrintTupleAnnotation(GetRef(tty), var->shape_); + } else { + doc << Print(annotation); + } + } + return doc; +} + +Doc RelaxScriptPrinter::PrintTensorAnnotation(const relax::DynTensorType& ty, + const Optional& shape) { + Doc doc; + doc << "Tensor("; + // Print shape annotation + if (shape.defined()) { + doc << Print(Downcast(shape.value())); + } else { + doc << "None"; + } + // Print dtype annotation + doc << ", "; + if (ty->dtype.is_void()) { + doc << "_"; + } else { + doc << PrintDType(ty->dtype); + } + // Print ndim annotation only when it cannot be inferred from shape itself. + if (!shape.defined() || shape->IsInstance()) { + doc << ", ndim = " << ty->ndim; + } + doc << ")"; + return doc; +} + +Doc RelaxScriptPrinter::PrintTupleAnnotation(const TupleType& ty, + const Optional& shape) { + Doc doc; + doc << "Tuple"; + std::vector fields; + for (size_t i = 0; i < ty->fields.size(); i++) { + if (shape) { + if (const TupleNode* shape_tuple = shape.value().as()) { + if (const DynTensorTypeNode* type_field = ty->fields[i].as()) { + fields.push_back( + PrintTensorAnnotation(GetRef(type_field), shape_tuple->fields[i])); + } + } + } else { + if (const DynTensorTypeNode* type_field = ty->fields[i].as()) { + fields.push_back(PrintTensorAnnotation(GetRef(type_field), NullOpt)); + } + } + } + doc << "(" << Doc::Concat(fields, Doc::Text(", ")) << ")"; + return doc; +} + +Doc RelaxScriptPrinter::GetUniqueName(std::string prefix, std::string fallback = "x") { + if (prefix.empty()) { + prefix = fallback; + } + return Doc::Text(name_table_.GetUniqueName(prefix)); +} + +bool RelaxScriptPrinter::ShowMetaData() { return show_meta_data_; } + +String AsRelaxScript(const ObjectRef& mod, bool show_meta_data) { + ICHECK(mod->IsInstance() || mod->IsInstance() || + mod->IsInstance()); + Doc doc; + runtime::TypedPackedFunc ftyped = nullptr; + doc << TextPrinter(show_meta_data, ftyped).PrintRelax(mod); + return doc.str(); +} + +TVM_REGISTER_GLOBAL("script.AsRelaxScript").set_body_typed(AsRelaxScript); + +} // namespace relax +} // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 76cac28b07..84e2f88112 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -755,6 +755,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) { return doc; } +Doc RelayTextPrinter::VisitType_(const relax::DynTensorTypeNode* node) { + Doc doc; + doc << "Tensor[ndim=" << node->ndim << ", dtype=\"" << PrintDType(node->dtype) << "\"]"; + return doc; +} + +Doc RelayTextPrinter::VisitType_(const relax::ObjectTypeNode* node) { + Doc doc; + doc << "Object"; + return doc; +} + //------------------------------------ // Overload of Attr printing functions //------------------------------------ diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 4d4113fef6..29de9a33a3 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -81,6 +81,8 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { } else if (base_func.as()) { doc << "@" << var->name_hint; doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast(base_func)); + } else if (base_func.as()) { + doc << relax_text_printer_.Print(base_func); } doc << Doc::NewLine(); } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 2dc0997f82..ac5d71efeb 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -28,6 +28,8 @@ #include #include +#include +#include #include 
#include #include @@ -185,6 +187,8 @@ class RelayTextPrinter : public ExprFunctor, Doc VisitType_(const FuncTypeNode* node) final; Doc VisitType_(const RelayRefTypeNode* node) final; Doc VisitType_(const TypeDataNode* node) final; + Doc VisitType_(const relax::DynTensorTypeNode* node) final; + Doc VisitType_(const relax::ObjectTypeNode* node) final; //------------------------------------ // Overload of Attr printing functions //------------------------------------ @@ -233,6 +237,133 @@ class RelayTextPrinter : public ExprFunctor, } // namespace relay } // namespace tvm +namespace tvm { +namespace relax { +class RelaxScriptPrinter : public relax::IRFunctor, + public tir::ExprFunctor, + public TypeFunctor, + public AttrFunctor { + public: + explicit RelaxScriptPrinter(bool show_meta_data, TextMetaDataContext* meta) + : show_meta_data_(show_meta_data), meta_(meta) {} + TVM_DLL Doc Print(const ObjectRef& node); + bool ShowMetaData(); + + private: + NameTable name_table_; + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief A counter for naming local functions. */ + size_t local_func_counter_ = 0; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + std::unordered_map var_id_map_; + std::unordered_map dim_var_map_; + + // IR nodes inherited from Relay + Doc VisitNode_(const relay::TupleNode* op) override; + Doc VisitNode_(const relay::GlobalVarNode* op) override; + Doc VisitNode_(const relay::ConstantNode* op) override; + Doc VisitNode_(const relay::CallNode* op) override; + // Doc VisitNode_(const relay::IfNode* op) override; + Doc VisitNode_(const OpNode* op) override; + Doc VisitNode_(const relay::TupleGetItemNode* op) override; + + // IR nodes introduced by Relax + Doc VisitNode_(const relax::VarNode* op) override; + Doc VisitNode_(const relax::DataflowVarNode* op) override; + Doc VisitNode_(const relax::ShapeExprNode* op) override; + Doc VisitNode_(const relax::RuntimeDepShapeNode* op) override; + Doc VisitNode_(const relax::MatchShapeNode* op) override; + Doc VisitNode_(const relax::VarBindingNode* op) override; + Doc VisitNode_(const relax::BindingBlockNode* op) override; + Doc VisitNode_(const relax::DataflowBlockNode* op) override; + Doc VisitNode_(const relax::SeqExprNode* op) override; + Doc VisitNode_(const relax::FunctionNode* op) override; + Doc VisitNode_(const relax::ExternFuncNode* op) override; + + // PrimExpr nodes allowed in Relax + Doc VisitExpr_(const tir::VarNode* op) override; + Doc VisitExpr_(const tir::IntImmNode* op) override; + Doc VisitExpr_(const tir::AddNode* op) override; + Doc VisitExpr_(const tir::SubNode* op) override; + Doc VisitExpr_(const tir::MulNode* op) override; + Doc VisitExpr_(const tir::DivNode* op) override; + Doc VisitExpr_(const tir::FloorDivNode* op) override; + Doc VisitExpr_(const tir::CastNode* op) override; + Doc VisitExpr_(const tir::MaxNode* op) override; + + Doc PrintIRModule(const IRModule& mod); + Doc PrintPrimFunc(const String& name, const tir::PrimFunc& func); + + Doc PrintIfStmt(const relax::Var& var, const relay::If& ite); + Doc PrintFunctionDef(const Doc& name, const relax::Function& func); + + Doc PrintVarAnnotation(const relax::Var& var); + Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional& shape); + Doc PrintTupleAnnotation(const TupleType& ty, const Optional& shape); + + Doc VisitType_(const relax::ShapeTypeNode* node) override; + Doc VisitType_(const relax::ObjectTypeNode* node) override; + Doc VisitType_(const relax::DynTensorTypeNode* node) override; + Doc 
VisitType_(const relay::TupleTypeNode* node) override; + Doc VisitType_(const relay::FuncTypeNode* node) override; + + Doc PrintAttr(const ObjectRef& attr); + std::vector PrintAttrs(const Attrs& attrs); + std::vector PrintTypeArgs(const Array& type_args); + Doc VisitAttrDefault_(const Object* op) override; + Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true); + Doc VisitAttr_(const ArrayNode* op) override; + Doc VisitAttr_(const tir::IntImmNode* op) override; + Doc VisitAttr_(const tir::FloatImmNode* op) override; + + Doc GetUniqueName(std::string prefix, std::string fallback); + + /*! + * \brief Attribute printer which prints the attributes as kwargs in a call. + */ + class AttrPrinter : public AttrVisitor { + public: + AttrPrinter(std::vector* docs, RelaxScriptPrinter* parent) : docs(docs), parent_(parent) {} + + template + void PrintKV(const char* key, const T& value) { + Doc doc; + doc << key << "=" << value; + docs->push_back(doc); + } + + void Visit(const char* key, double* value) final { PrintKV(key, *value); } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, int* value) final { PrintKV(key, *value); } + void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } + void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "do not allow void as argument"; + } + void Visit(const char* key, DataType* value) final { + PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); + } + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "do not allow NDarray as argument"; + } + void Visit(const char* key, runtime::ObjectRef* obj) final { + PrintKV(key, parent_->PrintAttr(*obj)); + } + + private: + std::vector* docs; + RelaxScriptPrinter* parent_; + }; +}; + +String AsRelaxScript(const ObjectRef& mod, bool show_meta_data); + +} // namespace relax +} // namespace tvm + namespace tvm { namespace tir { @@ -342,7 +473,6 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitExpr_(const ShuffleNode* op) override; Doc VisitExpr_(const ReduceNode* op) override; Doc VisitExprDefault_(const Object* op) override; - Doc VisitStmt_(const LetStmtNode* op) override; Doc VisitStmt_(const AttrStmtNode* op) override; Doc VisitStmt_(const AssertStmtNode* op) override; @@ -367,8 +497,8 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitType_(const PointerTypeNode* node) override; Doc VisitType_(const TupleTypeNode* node) override; - Doc PrintIRModule(const IRModule& module); Doc PrintPrimFunc(const PrimFunc& primFunc); + Doc PrintIRModule(const IRModule& module); Doc PrintArray(const ArrayNode* op); Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); @@ -410,6 +540,9 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool sh String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate); +Doc AsTVMScriptDoc(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false, + const PrimFunc& func = PrimFunc()); + } // namespace tir } // namespace tvm @@ -424,6 +557,7 @@ class TextPrinter { show_warning_(show_warning), annotate_(annotate), relay_text_printer_(show_meta_data, &meta_, annotate), + relax_text_printer_(show_meta_data, 
&meta_), tir_text_printer_(show_meta_data, &meta_) {} /*! \brief whether show meta data */ @@ -438,6 +572,8 @@ class TextPrinter { runtime::TypedPackedFunc annotate_; /*! \brief Relay Text Printer */ relay::RelayTextPrinter relay_text_printer_; + /*! \brief Relax Text Printer */ + relax::RelaxScriptPrinter relax_text_printer_; /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; @@ -451,6 +587,11 @@ class TextPrinter { (node->IsInstance() || node->IsInstance() || node->IsInstance())) { doc << tir_text_printer_.Print(node); + } else if (node.defined() && + (node->IsInstance() || node->IsInstance() || + node->IsInstance() || + node->IsInstance())) { + doc << relax_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); } @@ -468,6 +609,18 @@ class TextPrinter { return doc; } + Doc PrintRelax(const ObjectRef& node) { + relax_text_printer_.Print(node); + Doc doc; + if (show_meta_data_ && !meta_.empty()) { + doc << "metadata = "; + doc << meta_.GetMetaSection(); + doc << Doc::NewLine(); + } + doc << relax_text_printer_.Print(node); + return doc; + } + Doc PrintMod(const IRModule& mod); }; } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 936ac7580f..f9d1b70876 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -448,6 +449,12 @@ class TVMScriptPrinter : public StmtFunctor, } return header; } + + void UpdateFuncName(const IRModule& mod) { + for (const auto& x : mod->functions) { + func2var_[x.second.get()] = x.first; + } + } }; /*! @@ -747,6 +754,10 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintCommReducer(node.as()); } else if (node->IsInstance()) { return PrintTarget(node.as()); + } else if (node->IsInstance()) { + Doc doc; + doc << "te.Tensor"; + return doc; } else { LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); return Doc(); @@ -1579,9 +1590,7 @@ Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) { Doc doc; doc << "@tvm.script.ir_module" << Doc::NewLine(); doc << "class Module:"; - for (const auto& x : op->functions) { - func2var_[x.second.operator->()] = x.first; - } + UpdateFuncName(module); Doc body = Doc::NewLine(); std::vector functions; for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { @@ -1983,6 +1992,14 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { } String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { + // Temporary redirect possibly relax related printing to relax script + // TODO(tvm-team): make relax script printer handle all possible cases and + // make that as a default of TVMScript printer + if (mod->IsInstance() || mod->IsInstance()) { + // TODO(tvm-team) support tir_prefix in relax printer + return relax::AsRelaxScript(mod, show_meta); + } + ICHECK(mod->IsInstance() || mod->IsInstance()); Doc doc; doc << TVMScriptPrinter::PrintHeader(tir_prefix) @@ -1990,6 +2007,19 @@ String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_met return doc.str() + "\n"; } +Doc AsTVMScriptDoc(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + const PrimFunc& func) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + TVMScriptPrinter printer(tir_prefix, show_meta); + // TODO(altan, tqchen): change to the first argument only? 
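+  // When `func` is defined, only that single function is printed; otherwise the
+  // whole module is printed (see the ternary below).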
+ Doc doc; + if (mod->IsInstance()) { + printer.UpdateFuncName(Downcast(mod)); + } + doc << (func.defined() ? printer.Print(func) : printer.Print(mod)) << Doc::NewLine(); + return doc; +} + TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc new file mode 100644 index 0000000000..bcad70cec9 --- /dev/null +++ b/src/relax/analysis/analysis.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file analysis.cc + * + * \brief Analysis functions for Relax. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class VarVisitor : protected ExprVisitor { + public: + Array Free(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array Collect() { + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + return Collect(); + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array AllGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array CalledGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : called_global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + + void VisitExpr_(const FunctionNode* op) final { + for (const auto& param : op->params) { + MarkBounded(param); + } + VisitExpr(op->body); + } + void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + + void VisitExpr_(const CallNode* call_node) final { + VisitSpan(call_node->span); + VisitExpr(call_node->op); + + for (Type ty_arg : call_node->type_args) { + VisitType(ty_arg); + } + + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + + if (call_node->shape_) { + VisitExpr(Downcast(call_node->shape_.value())); + } + + if (const GlobalVarNode* global_var_node = call_node->op.as()) { + called_global_vars_.Insert(GetRef(global_var_node)); + } + } + + void VisitBinding_(const 
VarBindingNode* binding) final { + MarkBounded(binding->var); + VisitExpr(binding->value); + VisitVarDef(binding->var); + } + + void VisitBinding_(const MatchShapeNode* binding) final { + if (binding->var.defined()) { + MarkBounded(binding->var); + } + ExprVisitor::VisitBinding_(binding); + } + + private: + InsertionSet vars_; + InsertionSet bound_vars_; + InsertionSet global_vars_; + InsertionSet called_global_vars_; +}; + +class DimVisitor : public tir::ExprVisitor { + public: + void VisitExpr_(const tir::VarNode* v) override { vars_.Insert(GetRef(v)); } + + InsertionSet vars_; +}; + +class ShapeVarVisitor : public ExprVisitor { + public: + void VisitExpr_(const ShapeExprNode* shape) override { + for (auto dim : shape->values) { + DimVisitor v; + v(dim); + for (auto v : v.vars_.data) { + shape_vars_.Insert(v); + } + } + } + + InsertionSet shape_vars_; +}; + +tvm::Array ShapeVars(const Expr& expr) { + ShapeVarVisitor s; + s.VisitExpr(expr); + Array ret; + for (auto v : s.shape_vars_.data) { + ret.push_back(v); + } + return ret; +} + +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } + +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } + +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } + +tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } + +tvm::Array CalledGlobalVars(const Expr& expr) { + return VarVisitor().CalledGlobalVars(expr); +} + +TVM_REGISTER_GLOBAL("relax.analysis.shape_vars").set_body_typed(ShapeVars); + +TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); + +TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); + +TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/func_ret_shape.cc b/src/relax/analysis/func_ret_shape.cc new file mode 100644 index 0000000000..45f283e308 --- /dev/null +++ b/src/relax/analysis/func_ret_shape.cc @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +Expr DeriveFuncRetShape(Array args, Expr body) { + std::unordered_set arg_shape_var_set; + for (auto v : args) { + if (const ExprNode* s = v->shape_.as()) { + Expr shape = GetRef(s); + Array arg_shape_vars = ShapeVars(shape); + for (auto v : arg_shape_vars) { + arg_shape_var_set.insert(v); + } + } + } + + if (const ExprNode* s = body->shape_.as()) { + Expr body_shape = GetRef(s); + Array body_shape_vars = ShapeVars(body_shape); + for (auto v : body_shape_vars) { + // if the body shape contains a free var, then we can't + // be more specific than RuntimeDepShape + if (arg_shape_var_set.count(v) == 0) { + return RuntimeDepShape(); + } + } + // all vars are defined in the args, so we can use the body shape + return body_shape; + } + return RuntimeDepShape(); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.derive_func_ret_shape")).set_body_typed(DeriveFuncRetShape); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc new file mode 100644 index 0000000000..a1ea34804b --- /dev/null +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using namespace tir; + +class PatternKindAnalyzer : public StmtExprVisitor { + public: + explicit PatternKindAnalyzer(const tir::PrimFunc& func) { + for (const tir::Var& param : func->params) { + param_buffers_.insert(func->buffer_map.Get(param).value()); + } + } + + private: + bool IsOutputBlock(const BlockNode* block) { + for (const BufferRegion& write_region : block->writes) { + if (param_buffers_.count(write_region->buffer)) { + return true; + } + } + return false; + } + + void VisitStmt_(const BufferStoreNode* op) final { + ICHECK(!store_.defined()); + store_ = GetRef(op); + StmtVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode* op) final { + loads_.push_back(GetRef(op)); + ExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BlockNode* op) final { + if (op->name_hint == "root") { + // Skip the root block + StmtVisitor::VisitStmt(op->body); + return; + } + + // Step 1. Clear loads and store + loads_.clear(); + store_ = NullOpt; + // Step 2. Visit block body. + StmtVisitor::VisitStmt(op->body); + BufferStore store = store_.value(); + + // Step 3. 
Checking load store indices pattern
+    relay::OpPatternKind index_pair_pattern = relay::kElemWise;
+    bool has_elem_wise = false;
+    for (const BufferLoad& load : loads_) {
+      // Since elemwise is stricter than broadcast, and broadcast is stricter than injective
+      // (the order among the enums is kElemWise < kBroadcast < kInjective),
+      // we can simply use `std::max` to detect these three patterns.
+      // E.g. if there is only one store node but two load nodes, as in C[i, j] = A[i, j] + B[i],
+      // buffers C and A are elemwise but C and B are broadcast, so the whole block follows the
+      // broadcast pattern.
+      if (IsElemwisePattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
+        has_elem_wise = true;
+      } else if (IsBroadcastPattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
+      } else if (IsInjectivePattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
+      } else {
+        index_pair_pattern = relay::kOpaque;
+        break;
+      }
+    }
+    // If one index pair is kElemWise and the others are kBroadcast, we regard the block as
+    // kElemWise, e.g. A[i, j] = B[i, j] + C[i]
+    if (index_pair_pattern == relay::kBroadcast && has_elem_wise) {
+      index_pair_pattern = relay::kElemWise;
+    }
+    // If the block index pattern is not opaque, update kind.
+    if (index_pair_pattern != relay::kOpaque) {
+      // This rule is for softmax: reduce + injective.
+      if (IsOutputBlock(op) && kind_ == relay::kCommReduce) {
+        kind_ = relay::kOutEWiseFusable;
+      } else {
+        kind_ = std::max(kind_, index_pair_pattern);
+      }
+      return;
+    }
+
+    // Step 4. Check whether the block contains reduce axes by looking into block iterators.
+    bool has_reduction = false;
+    Array<tir::Var> reduce_vars;
+    for (const IterVar& it : op->iter_vars) {
+      if (it->iter_type == kCommReduce) {
+        has_reduction = true;
+        reduce_vars.push_back(it->var);
+      }
+    }
+
+    if (has_reduction) {
+      if (IsFMA(op->body)) {
+        // FMA is regarded as kOutEWiseFusable, e.g. Matmul or Conv.
+        kind_ = std::max(kind_, relay::kOutEWiseFusable);
+        return;
+      } else {
+        for (size_t i = 0; i < loads_.size(); ++i) {
+          // If it is not a pure reduce, regard it as kOutEWiseFusable.
+          // This rule works for pooling for now.
+          if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
+            kind_ = std::max(kind_, relay::kOutEWiseFusable);
+            return;
+          }
+        }
+      }
+      kind_ = std::max(kind_, relay::kCommReduce);
+    } else {
+      kind_ = relay::kOpaque;
+    }
+  }
+
+  /********** Helper Functions **********/
+
+  /*! \brief Check whether two index arrays contain the same elements. */
+  static bool IsSameArray(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (!lhs[i].same_as(rhs[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Check whether the load indices and store indices follow the elemwise pattern.
+   * It is an elemwise pattern iff the load indices and the store indices are the same.
+   * E.g. A[i, j] = B[i, j]
+   */
+  static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) {
+    return IsSameArray(store->indices, load->indices);
+  }
+
+  /*!
+   * \brief Check whether the load indices and store indices follow the broadcast pattern.
+   * It is a broadcast pattern iff all load indices appear in the store indices in order.
+   * E.g. A[i, j] = B[i] is broadcast since all load indices (`i`) are in the store indices.
+   *      A[i, j] = B[i, k] is not broadcast since `k` is not in the store indices.
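+   *      A[i, j] = B[1, j] is still broadcast: unit dimensions loaded at constant
+   *      index 0 are skipped (see the check below).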
+ * A[i, j] = B[j, i] is not broadcast the load indices are not in the same order as store's + */ + static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) { + size_t ndim_load_buf = load->buffer->shape.size(); + size_t ndim_store_buf = store->buffer->shape.size(); + + for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) { + if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) { + // Skip unit load dimensions + // E.g. A[i, j] = B[1, j] is still broadcast + continue; + } + + // Try to find the i-th load indice in the store indices. + while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) { + ++j; + } + + // It's not broadcast if we cannot find load indices in the store indices in order. + if (j == ndim_store_buf) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices follows injective pattern. + * It's injective pattern iff all load indice vars are in the store indices, no matter orders. + * Note that we only support store indices are direct vars so far, which can be enhance later. + * E.g. A[i, j] = B[j, i] is injective. + * A[i, j] = B[i - j] is injective since the load indice vars are only i, j + */ + static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) { + std::unordered_set vars; + for (const PrimExpr& store_index : store->indices) { + if (const auto* v = store_index.as()) { + vars.insert(v); + } else { + return false; + } + } + for (const PrimExpr& load_index : load->indices) { + // return false if there are vars used in load indices but not in store indices. + if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) { + return false; + } + } + return true; + } + + /*! + * \brief Checking the load indices and store indices allow data reuse. + * It allow data reuse iff there is any vars in load indices but they are not in store indices + * E.g. Store = A[i, j] and Load = B[i, j, k] allow data reuse. + * Store = A[i, j] and Load = B[i, j + k] allow data reuse. + */ + static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) { + std::unordered_set vars; + for (const PrimExpr& index : store->indices) { + if (const auto* v = index.as()) { + vars.insert(v); + } else { + return false; + } + } + for (const PrimExpr& index : load->indices) { + PreOrderVisit(index, [&](const ObjectRef& node) { + if (const auto* v = node.as()) { + if (vars.count(v)) { + vars.erase(v); + } + } + return true; + }); + } + return !vars.empty(); + } + + /*! \brief Checking if the stmt is multiply add. E.g. C[i, j] += A[i, k] * B[j, k] */ + static bool IsFMA(const Stmt& body) { + if (const auto* store = body.as()) { + if (const auto* add = store->value.as()) { + if (const auto* l = add->a.as()) { + if (const auto* r = add->b.as()) { + bool incremental = + store->buffer.same_as(l->buffer) && IsSameArray(store->indices, l->indices); + const auto* l_load = r->a.as(); + const auto* r_load = r->b.as(); + if (incremental && l_load && r_load) { + return IsAllowReusePattern(GetRef(store), GetRef(l_load)) && + IsAllowReusePattern(GetRef(store), GetRef(r_load)); + } + } + } + } + } + return false; + } + + /*! + * \brief Checking if it is pure reduce pattern. + * It's pure reduce pattern iff all reduces axis are directly reduce var + * E.g. 
+   * E.g. A[i] = sum(B[i, j]) is pure reduce.
+   *      A[i] = sum(B[i, j + k]) is not pure reduce.
+   *      Pooling is not pure reduce.
+   */
+  static bool IsPureReducePattern(Array<tir::Var> reduce_loops, Array<PrimExpr> indices) {
+    for (const PrimExpr& e : indices) {
+      int id = -1;
+      if (UsesVar(e, [&](const tir::VarNode* var) {
+            for (size_t i = 0; i < reduce_loops.size(); ++i) {
+              if (reduce_loops[i].get() == var) {
+                id = i;
+                return true;
+              }
+            }
+            return false;
+          })) {
+        if (!reduce_loops[id].same_as(e)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+ private:
+  /*!
+   * \brief The BufferStore node in the current block.
+   * \note We only support one BufferStore node in a block (usually generated by TE compute).
+   */
+  Optional<BufferStore> store_;
+  /*! \brief The BufferLoad nodes in the current block. */
+  Array<BufferLoad> loads_;
+  /*! \brief The resulting op pattern. */
+  relay::OpPatternKind kind_ = relay::kElemWise;
+  /*! \brief The buffers from function params, i.e. the input and output buffers. */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> param_buffers_;
+
+ public:
+  relay::OpPatternKind GetResult() { return kind_; }
+};
+
+relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) {
+  PatternKindAnalyzer analyzer(func);
+  analyzer(func->body);
+  return analyzer.GetResult();
+}
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
new file mode 100644
index 0000000000..f3d9b4686b
--- /dev/null
+++ b/src/relax/analysis/udchain.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/analysis/udchain.cc
+ * \brief Implementation of use-def analysis.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+class UDChain : public relax::ExprVisitor {
+ public:
+  // nullptr users means it is the output of the function.
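+  // For example (illustrative): given bindings y = f(x); z = g(y) with z
+  // returned, to_users maps x -> {y}, y -> {z}, and z -> {nullptr}.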
+ std::map> to_users; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); } + void VisitVarDef(const Var& var) override { to_users[var.get()] = {}; } + void VisitExpr_(const FunctionNode* op) override { ExprVisitor::VisitExpr_(op); } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + +std::pair>, runtime::Array> FunctionUseDef( + const Function& fn) { + UDChain udchain; + udchain.VisitExpr_(fn.get()); + + Map> user_map; + Array fn_outs; + + for (const auto& kv : udchain.to_users) { + Array uses{}; + uses.reserve(kv.second.size()); + for (const auto& v : kv.second) { + if (nullptr == v && + fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(), GetRef(kv.first))) { + fn_outs.push_back(GetRef(kv.first)); + } else { + uses.push_back(GetRef(v)); + } + } + user_map.Set(GetRef(kv.first), std::move(uses)); + } + return std::make_pair(std::move(user_map), std::move(fn_outs)); +} + +runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { + UDChain udchain; + udchain.VisitBindingBlock_(dfb.get()); + runtime::Map> ret; + for (const auto& kv : udchain.to_users) { + Array uses{}; + uses.reserve(kv.second.size()); + for (const auto& v : kv.second) uses.push_back(GetRef(v)); + ret.Set(GetRef(kv.first), std::move(uses)); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc new file mode 100644 index 0000000000..680a2a7261 --- /dev/null +++ b/src/relax/analysis/var2value.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +namespace tvm { +namespace relax { +class Var2ValAnalysis : public relax::ExprVisitor { + public: + tvm::runtime::Map var2value_; + void VisitBinding_(const VarBindingNode* binding) override { + var2value_.Set(binding->var, binding->value); + } +}; + +tvm::runtime::Map AnalyzeVar2Value(const Expr& expr) { + Var2ValAnalysis var2val_analysis; + var2val_analysis.VisitExpr(expr); + return std::move(var2val_analysis.var2value_); +} + +tvm::runtime::Map AnalyzeVar2Value(const DataflowBlock& dfb) { + Var2ValAnalysis var2val_analysis; + var2val_analysis.VisitBindingBlock_(dfb.get()); + return std::move(var2val_analysis.var2value_); +} + +tvm::runtime::Map AnalyzeVar2Value(const IRModule& m) { + Var2ValAnalysis var2val_analysis; + + for (const auto& it : m->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + var2val_analysis.VisitExpr(GetRef(n)); + } + } + + return std::move(var2val_analysis.var2value_); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { + return AnalyzeVar2Value(f); +}); + +class Name2BindingAnalysis : public relax::ExprVisitor { + public: + // runtime::Map is not suitable for doing in-place update. + // so we use standard container for internal usage. + std::map> name2bindings_; + void VisitBinding_(const VarBindingNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } + + void VisitBinding_(const MatchShapeNode* binding) override { + const auto& vname = binding->var->name_hint(); + name2bindings_[vname].push_back(GetRef(binding)); + } +}; + +Map> NameToBinding(const Function& fn) { + Name2BindingAnalysis analysis{}; + analysis.VisitExpr_(fn.get()); + return Map>(std::make_move_iterator(analysis.name2bindings_.begin()), + std::make_move_iterator(analysis.name2bindings_.end())); +} + +TVM_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc new file mode 100644 index 0000000000..8cf683bd66 --- /dev/null +++ b/src/relax/analysis/well_formed.cc @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/analysis/well_formed.cc + * \brief Check if the IRModule is well formed. If it's malformed, messages + * will be logged as Warning. This pass will check: + * 1. GlobalVars are defined before use. + * 2. Vars are defined before use. + * 3. Vars are defined exactly once. + * 4. Symbolic Vars are defined before use. + * 5. DataflowVars cannot be defined inside BindingBlock. + * 6. 
Vars defined in IfNode, except the return Var, are invisible
+ *        out of the If body. (May change for new AST designs.)
+ *    7. SeqExpr only serves as function body, or in the true and
+ *       false branches in IfNode.
+ *    8. The IR is in ANF:
+ *       (a) No nested call
+ *       (b) The fields of the Tuple can only be Var/DataflowVar/Constant/
+ *           ShapeExpr/RuntimeDepShape/Tuple
+ */
+#include
+#include
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace relax {
+
+class WellFormedChecker;
+
+/*! \brief Helper to visit PrimExprs in shape annotations and check that the symbolic vars in
+ * each PrimExpr are defined. */
+class PrimExprVisitor : public tir::ExprVisitor {
+ public:
+  std::unordered_set symbolic_var_set_;
+  WellFormedChecker* checker_;
+
+  explicit PrimExprVisitor(WellFormedChecker* checker) : checker_(checker) {}
+
+  void VisitExpr_(const tir::VarNode* op);
+};
+
+/*! \brief Helper to implement the well-formed check. */
+class WellFormedChecker : public relax::ExprVisitor {
+ public:
+  Optional diag_ctx;
+
+  bool well_formed = true;
+
+  explicit WellFormedChecker(const Optional& ctx)
+      : diag_ctx(ctx), prim_expr_visitor_(this) {}
+
+  void Malformed(Diagnostic diag) {
+    well_formed = false;
+    LOG(WARNING) << "This IR is not well formed: " << diag->message;
+  }
+
+  void RegisterGlobalVar(GlobalVar var) { global_var_set_.insert(var); }
+
+ private:
+  void VisitExpr_(const GlobalVarNode* op) {
+    GlobalVar var = GetRef(op);
+    if (global_var_set_.count(var) == 0) {
+      Malformed(Diagnostic::Error(var->span)
+                << "GlobalVar " << op->name_hint << " is not defined.");
+    }
+  }
+
+  void VisitExpr_(const TupleNode* op) {
+    for (size_t i = 0; i < op->fields.size(); i++) {
+      Expr expr = op->fields[i];
+      if (expr.as() || expr.as() || expr.as() ||
+          expr.as() || expr.as() || expr.as()) {
+        this->VisitExpr(expr);
+      } else {
+        Malformed(Diagnostic::Error(expr->span)
+                  << "Tuple is not in ANF form, field " << i << " gets " << expr->GetTypeKey());
+      }
+    }
+
+    if (op->shape_) {
+      this->VisitExpr(Downcast(op->shape_.value()));
+    }
+  }
+
+  void VisitExpr_(const VarNode* op) {
+    Var var = GetRef(op);
+    if (var_set_.count(var) == 0) {
+      Malformed(Diagnostic::Error(var->span) << "Var " << op->name_hint() << " is not defined.");
+    }
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) {
+    DataflowVar var = GetRef(op);
+    if (!is_dataflow_) {
+      Malformed(Diagnostic::Error(var->span)
+                << "DataflowVar " << op->name_hint() << " is used outside DataflowBlock.");
+    }
+    if (dataflow_var_set_.count(var) == 0) {
+      Malformed(Diagnostic::Error(var->span)
+                << "DataflowVar " << op->name_hint() << " is not defined.");
+    }
+  }
+
+  void VisitExpr_(const FunctionNode* op) {
+    // Save the var_set_ of the enclosing scope for this local function.
+    std::unordered_set previous_var_set_ = var_set_;
+    for (Var param : op->params) {
+      // Register symbolic vars defined in the shape annotations of function params.
+      if (param->shape_) {
+        Expr var_shape = Downcast(param->shape_);
+        if (var_shape.as() || var_shape.as()) {
+          VisitExpr(var_shape);
+        } else {
+          for (PrimExpr expr : Downcast(var_shape)->values) {
+            if (expr.as()) {
+              prim_expr_visitor_.symbolic_var_set_.insert(Downcast(expr));
+            } else {
+              prim_expr_visitor_(expr);
+            }
+          }
+        }
+      }
+
+      this->VisitVarDef(param);
+    }
+    this->VisitBody(op->body);
+    var_set_ = previous_var_set_;
+    prim_expr_visitor_.symbolic_var_set_.clear();
+  }
+
+  void VisitExpr_(const CallNode* op) {
+    for (size_t i = 0; i < op->args.size(); i++) {
+      Expr arg = op->args[i];
+      if (arg.as() || arg.as() || arg.as() ||
+          arg.as() || arg.as() || arg.as() ||
+          arg.as()) {
+        this->VisitExpr(arg);
+      } else {
+        Malformed(Diagnostic::Error(arg->span)
+                  << "Call is not in ANF form, arg " << i << " gets " << arg->GetTypeKey());
+      }
+    }
+
+    if (op->shape_) {
+      this->VisitExpr(Downcast(op->shape_.value()));
+    }
+  }
+
+  void VisitExpr_(const IfNode* op) {
+    this->VisitExpr(op->cond);
+    std::unordered_set previous_var_set_ = var_set_;
+    std::unordered_set previous_symbolic_var_set_ =
+        prim_expr_visitor_.symbolic_var_set_;
+    this->VisitBody(op->true_branch);
+    var_set_ = previous_var_set_;
+    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+    this->VisitBody(op->false_branch);
+    var_set_ = previous_var_set_;
+    prim_expr_visitor_.symbolic_var_set_ = previous_symbolic_var_set_;
+  }
+
+  void VisitExpr_(const ShapeExprNode* op) {
+    for (PrimExpr expr : op->values) {
+      // Check that the symbolic vars in the expr are defined, e.g., 2 * m.
+      prim_expr_visitor_(expr);
+      if (!expr.dtype().is_int()) {
+        Malformed(Diagnostic::Error(expr->span)
+                  << "Shape expressions must be of integer type, but got " << expr.dtype());
+      }
+    }
+  }
+
+  void VisitExpr_(const SeqExprNode* op) {
+    Malformed(Diagnostic::Error(op->span)
+              << "SeqExpr only serves as the function body in FunctionNode, "
+                 "or the true/false branch body in IfNode.");
+  }
+
+  void VisitSeqExpr(const SeqExprNode* op) {
+    // A special call, used only when the SeqExpr is the function body
+    // in FunctionNode or the true/false branch body in IfNode.
+    for (BindingBlock block : op->blocks) {
+      this->VisitBindingBlock(block);
+    }
+    this->VisitExpr(op->body);
+  }
+
+  void VisitBody(const Expr& expr) {
+    if (const SeqExprNode* seq_expr = expr.as()) {
+      this->VisitSeqExpr(seq_expr);
+    } else {
+      this->VisitExpr(expr);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) {
+    this->VisitExpr(binding->value);
+    this->VisitVarDef(binding->var);
+  }
+
+  void VisitBinding_(const MatchShapeNode* binding) {
+    this->VisitExpr(binding->value);
+    for (PrimExpr expr : binding->pattern) {
+      if (expr.as()) {
+        // Register symbolic vars implicitly defined in the pattern of MatchShape.
+        prim_expr_visitor_.symbolic_var_set_.insert(Downcast(expr));
+      } else {
+        // Check that the symbolic vars in the expr are defined, e.g., 2 * m.
+        prim_expr_visitor_(expr);
+      }
+    }
+
+    if (binding->var.defined()) {
+      this->VisitVarDef(binding->var);
+    }
+  }
+
+  void VisitBindingBlock_(const DataflowBlockNode* block) {
+    is_dataflow_ = true;
+    for (Binding binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    is_dataflow_ = false;
+    dataflow_var_set_.clear();
+  }
+
+  void VisitVarDef_(const DataflowVarNode* var) {
+    if (!is_dataflow_) {
+      Malformed(Diagnostic::Error(var->span)
+                << "DataflowVar " << var->name_hint() << " is defined outside DataflowBlock.");
+    }
+    DataflowVar lv = GetRef(var);
+    if (dataflow_var_set_.count(lv) == 1) {
+      Malformed(Diagnostic::Error(var->span)
+                << "DataflowVar " << lv->name_hint() << " is defined more than once.");
+    }
+    // Register the DataflowVar.
+    dataflow_var_set_.insert(lv);
+  }
+
+  void VisitVarDef_(const VarNode* var) {
+    Var gv = GetRef(var);
+    if (var_set_.count(gv) == 1) {
+      Malformed(Diagnostic::Error(var->span)
+                << "Var " << gv->name_hint() << " is defined more than once.");
+    }
+    // Register the Var.
+    var_set_.insert(gv);
+  }
+
+  void VisitVarDef(const Var& var) {
+    if (const DataflowVarNode* lv_node = var.as()) {
+      VisitVarDef_(lv_node);
+    } else if (const VarNode* gv_node = var.as()) {
+      VisitVarDef_(gv_node);
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
+    }
+
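+    // Also check the symbolic shape annotation (if any) attached to this definition.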
if (var->shape_) { + VisitExpr(Downcast(var->shape_.value())); + } + } + + bool is_dataflow_ = false; + std::unordered_set global_var_set_; + std::unordered_set var_set_; + std::unordered_set dataflow_var_set_; + PrimExprVisitor prim_expr_visitor_; +}; + +void PrimExprVisitor::VisitExpr_(const tir::VarNode* op) { + tir::Var var = GetRef(op); + if (symbolic_var_set_.count(var) == 0) { + checker_->Malformed(Diagnostic::Error(var->span) + << "Symbolic Var " << var->name_hint << " is not defined."); + } +} + +bool WellFormed(const IRModule& m, Optional diag_ctx) { + WellFormedChecker well_formed_checker = WellFormedChecker(diag_ctx); + for (const auto& it : m->functions) { + // register GlobalVar in the IRModule first + well_formed_checker.RegisterGlobalVar(it.first); + } + + for (const auto& it : m->functions) { + // visit relax.Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + well_formed_checker.VisitExpr(func); + } + } + + return well_formed_checker.well_formed; +} + +TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed([](IRModule m) { + return WellFormed(m); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h new file mode 100644 index 0000000000..e44f2205f2 --- /dev/null +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/codegen_json/codegen_json.h + * \brief Utilities for json codegen and runtime + */ +#ifndef TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ +#define TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../../../runtime/contrib/json/json_node.h" +#include "../../../../runtime/contrib/json/json_runtime.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace backend { +namespace contrib { + +using namespace tvm::runtime::json; + +using ShapeVector = std::vector>; +using TypeVector = std::vector; +using JSONGraphObjectPtr = std::shared_ptr; + +/*! + * \brief Helper class to extract all attributes of a certain op and save them + * into text format. 
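+ * For example (illustrative), an integer attribute axis=1 is stored as the
+ * string "1", and array attributes become vectors of strings.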
+ */ +class OpAttrExtractor : public AttrVisitor { + public: + explicit OpAttrExtractor(JSONGraphObjectPtr node) : node_(node) {} + + template ::value>> + std::string Fp2String(const T value) { + std::ostringstream out; + out.precision(std::numeric_limits::max_digits10); + out << value; + return out.str(); + } + + void SetNodeAttr(const char* key, const std::vector& value) { + std::vector attr; + attr.emplace_back(value); + node_->SetAttr(key, attr); + } + + void Visit(const char* key, double* value) final { SetNodeAttr(key, {Fp2String(*value)}); } + + void Visit(const char* key, int64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, uint64_t* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, int* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, bool* value) final { SetNodeAttr(key, {std::to_string(*value)}); } + + void Visit(const char* key, std::string* value) final { SetNodeAttr(key, {*value}); } + + void Visit(const char* key, DataType* value) final { + if (!value->is_void()) { + SetNodeAttr(key, {runtime::DLDataType2String(*value)}); + } else { + SetNodeAttr(key, {""}); + } + } + + void Visit(const char* key, runtime::ObjectRef* value) final { + if (const auto* an = (*value).as()) { + std::vector attr; + for (size_t i = 0; i < an->size(); ++i) { + if (const auto* im = (*an)[i].as()) { + attr.push_back(std::to_string(im->value)); + } else if (const auto* fm = (*an)[i].as()) { + attr.push_back(Fp2String(fm->value)); + } else if (const auto* str = (*an)[i].as()) { + String s = GetRef(str); + attr.push_back(s); + } else { + LOG(FATAL) << "Not supported type: " << (*an)[i]->GetTypeKey(); + } + } + SetNodeAttr(key, attr); + } else if (!(*value).defined()) { // Skip NullValue + SetNodeAttr(key, std::vector{""}); + } else if (const auto* im = (*value).as()) { + SetNodeAttr(key, std::vector{std::to_string(im->value)}); + } else if (const auto* fm = (*value).as()) { + SetNodeAttr(key, std::vector{Fp2String(fm->value)}); + } else if (const auto* str = (*value).as()) { + String s = GetRef(str); + SetNodeAttr(key, std::vector{s}); + } else { + LOG(FATAL) << "Not yet supported type: " << (*value)->GetTypeKey() << ": " << *value; + } + } + + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "NDArray is not allowed in op attribute"; + } + + void Visit(const char* key, void** value) final { + LOG(FATAL) << "void pointer is not allowed in op attribute"; + } + + void Extract(Object* node) { + if (node) { + reflection_->VisitAttrs(node, this); + } + } + + private: + JSONGraphObjectPtr node_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); +}; + +/*! \brief Serialize a Relax expression to JSON. */ +class JSONSerializer + : public tvm::relax::backend::MemoizedExprTranslator> { + public: + /*! + * \brief Constructor + * + * \param symbol The symbol that represents the graph being converted. + * \param expr The Relax expression to be converted to the JSON form. + */ + JSONSerializer(const std::string& symbol, const Expr& expr) : symbol_(symbol), func_(expr) {} + + void serialize() { + relax::Function func = Downcast(func_); + + // First we convert all the parameters into input nodes. + for (const auto& param : func->params) { + auto node_ptr = std::make_shared(param->name_hint(), "input" /* op_type_ */); + memo_[param] = AddNode(node_ptr, param); + } + heads_ = VisitExpr(func->body); + } + + /*!\brief Return the required params. 
*/
+  Array GetParams() const { return params_; }
+
+  /*! \brief Return the generated JSON. */
+  std::string GetJSON() {
+    std::ostringstream os;
+    dmlc::JSONWriter writer(&os);
+    Save(&writer);
+    return os.str();
+  }
+
+ protected:
+  /*!
+   * \brief Add a node to the graph.
+   *
+   * \param node A graph node. It is a shared pointer. Some attributes of it
+   *        will be added, i.e. shape and type. These attributes are attached to
+   *        the JSON graph in the end.
+   * \param expr The relax expression.
+   * \return A list of graph entry nodes. If the relax expr is a tuple type, we
+   *         will flatten it.
+   */
+  std::vector AddNode(JSONGraphObjectPtr node, const Expr& expr) {
+    auto checked_type = expr->checked_type();
+    auto node_id = nodes_.size();
+    nodes_.push_back(node);
+    std::vector ret;
+    ShapeVector shape;
+    TypeVector dtype;
+
+    // Flatten tuple node.
+    if (const auto* tuple_type = checked_type.as()) {
+      for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
+        const auto* tensor_type = tuple_type->fields[i].as();
+        ICHECK(tensor_type) << "Expect DynTensorType, but received: "
+                            << tuple_type->fields[i]->GetTypeKey();
+        ICHECK(expr->shape_.defined()) << "Expect shape to be defined.";
+        ShapeExpr output_shape = Downcast(expr->shape_.value());
+        ret.push_back(JSONGraphNodeEntry(node_id, i));
+        shape.emplace_back(GetIntShape(output_shape->values));
+        dtype.emplace_back(DType2String(tensor_type->dtype));
+      }
+      node->SetNumOutput(tuple_type->fields.size());
+    } else {
+      const auto* tensor_type = checked_type.as();
+      ICHECK(tensor_type) << "Expect DynTensorType, but received: " << checked_type->GetTypeKey();
+      ICHECK(expr->shape_.defined()) << "Expect shape to be defined.";
+      ShapeExpr output_shape = Downcast(expr->shape_.value());
+
+      shape.emplace_back(GetIntShape(output_shape->values));
+      dtype.emplace_back(DType2String(tensor_type->dtype));
+      ret.push_back(JSONGraphNodeEntry(node_id, 0));
+    }
+    std::vector shape_attrs;
+    shape_attrs.emplace_back(shape);
+    node->SetAttr("shape", shape_attrs);
+
+    std::vector type_attrs;
+    type_attrs.emplace_back(dtype);
+    node->SetAttr("dtype", type_attrs);
+    return ret;
+  }
+
+  void SetCallNodeAttribute(JSONGraphObjectPtr node, const CallNode* cn) {
+    if (cn->op.as()) {
+      OpAttrExtractor extractor(node);
+      const Object* call_attr = cn->attrs.get();
+      extractor.Extract(const_cast(call_attr));
+    } else if (const auto* fn = cn->op.as()) {
+      ICHECK(false);
+      auto pattern = fn->GetAttr(attr::kPartitionedFromPattern);
+      ICHECK(pattern.defined());
+      std::vector values;
+      values.push_back(pattern.value());
+      std::vector attr;
+      attr.emplace_back(values);
+      node->SetAttr("PartitionedFromPattern", attr);
+    }
+  }
+
+  std::vector VisitBinding_(const VarBindingNode* binding) {
+    ICHECK_EQ(memo_.count(binding->var), 0);
+    memo_[binding->var] = VisitExpr(binding->value);
+    return VisitExpr(binding->value);
+  }
+
+  std::vector VisitBinding_(const MatchShapeNode* binding) {
+    LOG(FATAL) << "JSON runtime currently doesn't support shape expressions";
+    return {};
+  }
+
+  std::vector VisitBinding(const Binding& binding) {
+    std::vector nodes;
+    if (const auto* node = binding.as()) {
+      auto from_b = VisitBinding_(node);
+      nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+    } else if (const auto* node = binding.as()) {
+      auto from_b = VisitBinding_(node);
+      nodes.insert(nodes.end(), from_b.begin(), from_b.end());
+    } else {
+      LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey();
+    }
+    return nodes;
+  }
+
+  std::vector VisitBindingBlock(const BindingBlock& block) {
+    std::vector
nodes; + if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else if (const auto* node = block.as()) { + auto from_bb = VisitBindingBlock_(node); + nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return nodes; + } + + std::vector VisitBindingBlock_(const BindingBlockNode* block) { + std::vector nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + std::vector VisitBindingBlock_(const DataflowBlockNode* block) { + std::vector nodes; + for (Binding binding : block->bindings) { + auto from_b = VisitBinding(binding); + nodes.insert(nodes.end(), from_b.begin(), from_b.end()); + } + return nodes; + } + + std::vector VisitExpr_(const SeqExprNode* op) { + std::vector nodes; + + for (BindingBlock block : op->blocks) { + auto from_bb = VisitBindingBlock(block); + // nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); + } + + auto from_body = VisitExpr(op->body); + nodes.insert(nodes.end(), from_body.begin(), from_body.end()); + + return nodes; + } + + std::vector VisitExprDefault_(const Object* op) { + LOG(FATAL) << "JSON runtime currently doesn't support " << op->GetTypeKey(); + return {}; + } + + std::vector VisitExpr_(const VarNode* vn) { + ICHECK(memo_.count(GetRef(vn))); + return memo_[GetRef(vn)]; + } + + std::vector VisitExpr_(const ConstantNode* cn) { + std::string name = symbol_ + "_const_" + std::to_string(params_.size()); + params_.push_back(name); + auto node = std::make_shared(name, "const" /* op_type_ */); + return AddNode(node, GetRef(cn)); + } + + std::vector VisitExpr_(const TupleNode* tn) { + std::vector fields; + for (const auto& field : tn->fields) { + auto ref = VisitExpr(field); + fields.insert(fields.end(), ref.begin(), ref.end()); + } + return fields; + } + + std::vector VisitExpr_(const CallNode* cn) { + Expr expr = GetRef(cn); + std::string name; + if (const auto* op_node = cn->op.as()) { + name = op_node->name; + } else if (const auto* fn = cn->op.as()) { + auto comp = fn->GetAttr(attr::kComposite); + ICHECK(comp.defined()) << "JSON runtime only supports composite functions."; + name = comp.value(); + } else { + LOG(FATAL) << "JSON runtime does not support calls to " << cn->op->GetTypeKey(); + } + + // TODO(@sunggg): Revisit when we have op naming convention. + // Currently, simply remove "relax." prefix to make it work. + name = std::string("tensorrt.") + name.substr(6); + + std::vector inputs; + for (const auto& arg : cn->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + auto node = std::make_shared(name, /* name_ */ + "kernel", /* op_type_ */ + inputs, 1 /* num_outputs_ */); + SetCallNodeAttribute(node, cn); + return AddNode(node, GetRef(cn)); + } + + std::vector VisitExpr_(const TupleGetItemNode* gtn) { + auto vtuple = VisitExpr(gtn->tuple); + return {vtuple[gtn->index]}; + } + + std::vector VisitExpr_(const FunctionNode* fn) { + ICHECK(fn->GetAttr(attr::kComposite).defined()) + << "JSON runtime only supports composite functions"; + + // FunctionNode should be handled by the caller. + return {}; + } + + /*! 
+ * \brief Save to JSON graph + * + * \param writer A json writer + */ + void Save(dmlc::JSONWriter* writer) { + std::vector arg_nodes; + for (size_t i = 0; i < nodes_.size(); ++i) { + auto node = nodes_[i]; + if (node->IsLeaf()) { + arg_nodes.push_back(i); + } + } + size_t num_entry = 0; + std::vector node_row_ptr{0}; + for (auto node : nodes_) { + num_entry += node->GetNumOutput(); + node_row_ptr.push_back(num_entry); + } + writer->BeginObject(); + writer->WriteObjectKeyValue("nodes", nodes_); + writer->WriteObjectKeyValue("arg_nodes", arg_nodes); + writer->WriteObjectKeyValue("heads", heads_); + writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr); + writer->EndObject(); + } + + private: + /*! \brief The symbol that represents the json graph. */ + std::string symbol_; + /*! \brief The function to be serialized. */ + const Expr func_; + /*! \brief JSON graph nodes. */ + std::vector nodes_; + /*! \brief Output of the JSON graph. */ + std::vector heads_; + /*! \brief The list of required constants. */ + Array params_; +}; + +} // namespace contrib +} // namespace backend +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_BACKEND_CONTRIB_CODEGEN_JSON_CODEGEN_JSON_H_ diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc new file mode 100644 index 0000000000..26679f0455 --- /dev/null +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/tensorrt/codegen.cc + * \brief Implementation of the TensorRT JSON serializer. + */ +#include +// TODO(sunggg): add operator attribute when it's ready +// #include +#include + +#include +#include +#include + +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +#if TVM_GRAPH_EXECUTOR_TENSORRT +#include "NvInfer.h" +#endif + +namespace tvm { +namespace relax { +namespace contrib { + +/*! \brief Attributes to store the compiler options for TensorRT. 
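+ * These options are typically supplied through the PassContext config under the
+ * key "relax.ext.tensorrt.options" (registered below).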
*/ +struct TensorRTCompilerConfigNode : public tvm::AttrsNode { + Array tensorrt_version; + bool use_implicit_batch; + size_t max_workspace_size; + bool remove_no_mac_subgraphs; + bool use_fp16; + bool use_uint8; + + TVM_DECLARE_ATTRS(TensorRTCompilerConfigNode, "relax.ext.attrs.TensorRTCompilerConfigNode") { + TVM_ATTR_FIELD(tensorrt_version) + .describe("TensorRT version as (major, minor, patch).") + .set_default(Array({6, 0, 1})); + TVM_ATTR_FIELD(use_implicit_batch).set_default(true); + TVM_ATTR_FIELD(max_workspace_size).set_default(size_t(1) << 30); + TVM_ATTR_FIELD(remove_no_mac_subgraphs).set_default(false); + TVM_ATTR_FIELD(use_fp16).set_default(false); + TVM_ATTR_FIELD(use_uint8).set_default(false); + } +}; + +class TensorRTCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorRTCompilerConfig, Attrs, + TensorRTCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(TensorRTCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relax.ext.tensorrt.options", TensorRTCompilerConfig); + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using OpAttrExtractor = backend::contrib::OpAttrExtractor; +using JSONSerializer = backend::contrib::JSONSerializer; + +class TensorRTJSONSerializer; + +/*! + * \brief Collect the constants and attributes from all operator calls in the body + * of a "Composite" function. + */ +class CollectFromCompositeFunctionBody : public ExprVisitor { + public: + explicit CollectFromCompositeFunctionBody(TensorRTJSONSerializer* serializer) + : serializer_(serializer), node_(std::make_shared()) {} + + void VisitExpr_(const ConstantNode* constant_node) final; + void VisitExpr_(const CallNode* call_node) final; + /* + void SetPadNodeAttribute(const CallNode* call_node) { + const auto* pad_attr = call_node->attrs.as(); + ICHECK(pad_attr); + auto p = pad_attr->pad_width; + const int dim_h = (p.size() == 5) ? 3 : 2; + const int dim_w = (p.size() == 5) ? 4 : 3; + std::vector padding = {std::to_string(p[dim_h][0].as()->value), + std::to_string(p[dim_w][0].as()->value), + std::to_string(p[dim_h][1].as()->value), + std::to_string(p[dim_w][1].as()->value)}; + std::vector padding_attr; + padding_attr.emplace_back(padding); + node_->SetAttr("padding", padding_attr); + } + + void SetStridedSliceNodeAttribute(const CallNode* call_node) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs && attrs->begin && attrs->end && attrs->strides) + << "StridedSlice must have static begin, end, and strides."; + const bool default_strides = + !attrs->strides.value().defined() || attrs->strides.value().size() == 0; + auto ishape = backend::GetShape(call_node->args[0]->checked_type()); + + auto process_slice_index = [](Integer x, int default_value, int dim_value) { + if (!x.defined()) return default_value; + int value = x.as()->value; + if (value < 0) value += dim_value; + return value; + }; + + std::vector start, size, strides; + for (size_t i = 0; i < attrs->begin.value().size(); ++i) { + const int begin_value = process_slice_index(attrs->begin.value()[i], 0, ishape[i]); + ICHECK_GE(begin_value, 0); + start.push_back(std::to_string(begin_value)); + const int stride_value = (default_strides || i >= attrs->strides.value().size() || + !attrs->strides.value()[i].defined()) + ? 
1
+                                       : attrs->strides.value()[i].as<IntImmNode>()->value;
+      ICHECK_GT(stride_value, 0);
+      strides.push_back(std::to_string(stride_value));
+      int size_value;
+      if (attrs->slice_mode == "end") {
+        const int end_value = process_slice_index(attrs->end.value()[i], ishape[i], ishape[i]);
+        size_value = (end_value - begin_value + stride_value - 1) / stride_value;
+      } else if (attrs->slice_mode == "size") {
+        // with slice_mode = "size", attrs->end_value means the size of the slice
+        int end_value = attrs->end.value()[i].as<IntImmNode>()->value;
+        size_value = (end_value == -1) ? ishape[i] - begin_value : end_value;
+      } else {
+        LOG(FATAL) << "Unexpected slice_mode " << attrs->slice_mode << ", expected end or size";
+        throw;
+      }
+      ICHECK_GT(size_value, 0);
+      size.push_back(std::to_string(size_value));
+    }
+    std::vector<dmlc::any> start_attr, size_attr, strides_attr;
+    start_attr.emplace_back(start);
+    size_attr.emplace_back(size);
+    strides_attr.emplace_back(strides);
+    node_->SetAttr("start", start_attr);
+    node_->SetAttr("size", size_attr);
+    node_->SetAttr("strides", strides_attr);
+  }
+
+  void SetSplitNodeAttribute(const CallNode* call_node) {
+    const auto* split_attr = call_node->attrs.as<SplitAttrs>();
+    ICHECK(split_attr);
+
+    std::vector<std::string> indices_or_sections;
+    std::vector<std::string> mode;
+    std::vector<std::string> axis = {std::to_string(split_attr->axis)};
+    if (const auto* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
+      mode.emplace_back("sections");
+      indices_or_sections.emplace_back(std::to_string(sections->value));
+    } else {
+      mode.emplace_back("indices");
+      auto indices = Downcast<Array<Integer>>(split_attr->indices_or_sections);
+      for (const auto& i : indices) {
+        indices_or_sections.emplace_back(std::to_string(i->value));
+      }
+    }
+
+    std::vector<dmlc::any> indices_or_sections_attr;
+    std::vector<dmlc::any> mode_attr;
+    std::vector<dmlc::any> axis_attr;
+    indices_or_sections_attr.emplace_back(indices_or_sections);
+    mode_attr.emplace_back(mode);
+    axis_attr.emplace_back(axis);
+    node_->SetAttr("indices_or_sections", indices_or_sections_attr);
+    node_->SetAttr("mode", mode_attr);
+    node_->SetAttr("axis", axis_attr);
+  }
+
+  void SetGenericAttributes(const CallNode* call_node) {
+    OpAttrExtractor extractor(node_);
+    const Object* attr_obj = call_node->attrs.get();
+    extractor.Extract(const_cast<Object*>(attr_obj));
+  }
+  */
+  TensorRTJSONSerializer* serializer_;
+  /*! \brief Accumulated translated arguments. */
+  std::vector<JSONGraphNodeEntry> args_;
+  /*!
+   * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the
+   * final JSONGraphNode however we don't yet know how many inputs that will have.
+   */
+  JSONGraphObjectPtr node_;
+};
+
+/*!
+ * \brief Generates a TensorRTModule from a Relax expression by serializing the expression to a
+ * JSON representation. TensorRT is not required here because use of TensorRT APIs is deferred
+ * until runtime.
+ */
+class TensorRTJSONSerializer : public JSONSerializer {
+ public:
+  TensorRTJSONSerializer(const std::string& symbol, const Expr& expr)
+      : JSONSerializer(symbol, expr) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* call_node) final {
+    // The call must be to an inline "Composite" function; otherwise defer to the base class.
+    const auto* function_node = call_node->op.as<FunctionNode>();
+    if (!function_node) return JSONSerializer::VisitExpr_(call_node);
+
+    auto opt_composite = function_node->GetAttr<String>(attr::kComposite);
+    ICHECK(opt_composite.defined());
+    std::string name = opt_composite.value();
+
+    // Collect the constants and attributes of all operator calls inside the composite body.
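+    // Illustrative flow (the composite name is hypothetical): for a composite function
+    // tagged kComposite = "relax.conv2d", the body is scanned for constants and attrs,
+    // the prefix rewrite below yields the JSON kernel node "tensorrt.conv2d", and the
+    // composite call's args (plus the collected constants) become that node's inputs.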
+    CollectFromCompositeFunctionBody collector(this);
+    collector.VisitExpr(function_node->body);
+
+    // Capture the args to the "Composite" function as inputs for this node.
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+
+    // Capture constants from the composite function body as additional inputs for this node.
+    for (const auto& node : collector.args_) {
+      inputs.emplace_back(node);
+    }
+    // TODO(@sunggg): Revisit when we have op naming convention.
+    // Currently, simply remove "relax." prefix to make it work.
+    name = std::string("tensorrt.") + name.substr(6);
+    // Create the final node.
+    auto node = std::make_shared<JSONGraphNode>(name,
+                                                /*op_type=*/"kernel", inputs,
+                                                /*num_output=*/1);
+
+    // Transfer attributes from the collector's node to the final node.
+    node->CaptureAttrs(*collector.node_);
+
+    // Capture global settings on the JSON node.
+    SaveGlobalAttributes(node);
+
+    VLOG(1) << name << " has " << node->GetInputs().size() << " inputs";
+
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
+  static void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
+    auto ctx = transform::PassContext::Current();
+    auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relax.ext.tensorrt.options");
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues<TensorRTCompilerConfig>();
+    }
+    ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3);
+    std::vector<std::string> tensorrt_version = {
+        std::to_string(cfg.value()->tensorrt_version[0].IntValue()),
+        std::to_string(cfg.value()->tensorrt_version[1].IntValue()),
+        std::to_string(cfg.value()->tensorrt_version[2].IntValue())};
+    std::vector<std::string> use_implicit_batch = {std::to_string(cfg.value()->use_implicit_batch)};
+    std::vector<std::string> max_workspace_size = {std::to_string(cfg.value()->max_workspace_size)};
+    std::vector<std::string> use_fp16 = {std::to_string(cfg.value()->use_fp16)};
+    std::vector<std::string> use_uint8 = {std::to_string(cfg.value()->use_uint8)};
+    std::vector<dmlc::any> tensorrt_version_attr, use_implicit_batch_attr, max_workspace_size_attr,
+        use_fp16_attr, use_uint8_attr;
+    tensorrt_version_attr.emplace_back(tensorrt_version);
+    use_implicit_batch_attr.emplace_back(use_implicit_batch);
+    max_workspace_size_attr.emplace_back(max_workspace_size);
+    use_fp16_attr.emplace_back(use_fp16);
+    use_uint8_attr.emplace_back(use_uint8);
+    node->SetAttr("tensorrt_version", tensorrt_version_attr);
+    node->SetAttr("use_implicit_batch", use_implicit_batch_attr);
+    node->SetAttr("max_workspace_size", max_workspace_size_attr);
+    node->SetAttr("use_fp16", use_fp16_attr);
+    node->SetAttr("use_uint8", use_uint8_attr);
+  }
+};
+
+void CollectFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) {
+  for (const auto& entry : serializer_->VisitExpr(GetRef<Expr>(constant_node))) {
+    args_.emplace_back(entry);
+  }
+}
+
+void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
+  const auto* op_node = call_node->op.as<OpNode>();
+  ICHECK(op_node != nullptr);
+  std::string name = op_node->name;
+  /*
+  // TODO(@sunggg): revisit when relax supports these ops.
+  if (name == "nn.pad") {
+    SetPadNodeAttribute(call_node);
+  } else if (name == "strided_slice") {
+    SetStridedSliceNodeAttribute(call_node);
+  } else if (name == "split") {
+    SetSplitNodeAttribute(call_node);
+  } else {
+    SetGenericAttributes(call_node);
+  }
+  */
+  ExprVisitor::VisitExpr_(call_node);
+}
+
+/*!
+ * \brief Create a runtime module for TensorRT.
+ * \param ref The ext_func Relax function to be executed using extern ops.
+ * \return A runtime module.
+ */ +runtime::Module TensorRTCompiler(const ObjectRef& ref) { + ICHECK(ref->IsInstance()) << "The input ref is expected to be a Relax function."; + Function func = Downcast(ref); + std::string func_name = backend::GetExtSymbol(func); + + VLOG(1) << "TensorRT partition:" << std::endl << PrettyPrint(func); + TensorRTJSONSerializer serializer(func_name, func); + serializer.serialize(); + std::string graph_json = serializer.GetJSON(); + VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; + auto param_names = serializer.GetParams(); + const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; + VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; + runtime::Module lib = (*pf)(func_name, graph_json, param_names); + return lib; +} + +TVM_REGISTER_GLOBAL("relax.ext.tensorrt").set_body_typed(TensorRTCompiler); + +/*! + * \brief Check whether TensorRT graph executor is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsTensorRTRuntimeEnabled() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return true; +#else + return false; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +/*! + * \brief Get TensorRT version that TVM is built against. + * \return Array of three integers for major, minor, and patch, or empty array if TensorRT graph + * runtime is not enabled. + */ +Array GetTensorRTVersion() { +#if TVM_GRAPH_EXECUTOR_TENSORRT + return {Integer(NV_TENSORRT_MAJOR), Integer(NV_TENSORRT_MINOR), Integer(NV_TENSORRT_PATCH)}; +#else + return {}; +#endif // TVM_GRAPH_EXECUTOR_TENSORRT +} + +TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled); +TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h new file mode 100644 index 0000000000..85d09f9cd8 --- /dev/null +++ b/src/relax/backend/contrib/utils.h @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/backend/contrib/utils.h + * \brief Utils function for backend + */ +#ifndef TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ +#define TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../../runtime/meta_data.h" +#include "../../../target/metadata.h" +#include "tvm/runtime/ndarray.h" + +namespace tvm { +namespace relax { +namespace backend { + +/*! + * \brief A simple wrapper around ExprFunctor for a single argument case. 
+ * The result of visit is memoized. + */ +template +class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor { + using BaseFunctor = ::tvm::relax::ExprFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~MemoizedExprTranslator() {} + + /*! + * \brief The memoized call. + * \param n The expression node. + * \return The result of the call + */ + virtual OutputType VisitExpr(const Expr& n) { + ICHECK(n.defined()); + auto it = memo_.find(n); + if (it != memo_.end()) { + return it->second; + } + auto res = BaseFunctor::VisitExpr(n); + memo_[n] = res; + return res; + } + + protected: + /*! \brief Internal map used for memoization. */ + std::unordered_map memo_; +}; + +/*! + * \brief Get the external symbol of the Relax function name. + * + * \param func The provided function. + * \return An external symbol. + */ +inline std::string GetExtSymbol(const Function& func) { + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(name_node.defined()) << "Fail to retrieve external symbol."; + return std::string(name_node.value()); +} + +/*! + * \brief Get the Packed Func + * + * \param func_name + * \return const PackedFunc* + */ +inline const PackedFunc* GetPackedFunc(const std::string& func_name) { + return tvm::runtime::Registry::Get(func_name); +} + +/*! + * \brief Extract shape from an IndexExpr array to std::vector + * + * \param shape The shape in Array + * \return The converted shape in std::vector + */ + +inline std::vector GetIntShape(const Array& shape) { + std::vector ret; + for (const auto& dim : shape) { + const int64_t* pval = tir::as_const_int(dim); + ret.push_back(pval ? *pval : -1); + } + return ret; +} + +/*! + * \brief Convert type to string + * + * \param typ + * \return std::string string format of type + */ +inline std::string DType2String(const tvm::DataType dtype) { + std::ostringstream os; + if (dtype.is_float()) { + os << "float"; + } else if (dtype.is_int()) { + os << "int"; + } else if (dtype.is_uint()) { + os << "uint"; + } else if (dtype.is_bfloat16()) { + os << "bfloat"; + } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) { + os << "custom[" + << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string() + << "]"; + } else { + LOG(FATAL) << "Unknown type with code " << static_cast(dtype.code()); + } + os << dtype.bits(); + return os.str(); +} + +} // namespace backend +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_CONTRIB_UTILS_H_ diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc new file mode 100644 index 0000000000..beb3950af1 --- /dev/null +++ b/src/relax/backend/task_extraction.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+using tvm::meta_schedule::ExtractedTask;
+
+/*!
+ * \brief Extract the Meta-Schedule tuning task from a given IRModule.
+ * \note
+ * 1. The task extractor is responsible for task deduplication. The
+ * deduplication is achieved by comparing structural hashes of PrimFuncs.
+ * 2. For a PrimFunc, the weight of its corresponding task is the number
+ * of times it is called by the call_tir op. Say in an IRModule there are three
+ * PrimFuncs `fn1`, `fn2` and `fn3` sharing the same structural hash.
+ * Suppose `fn1` is called by 5 call_tir ops among all Relax functions,
+ * `fn2` is called by 3 call_tir ops and `fn3` is called by 2 call_tir ops.
+ * Then we will have one ExtractedTask shared by all three functions, whose
+ * weight is 5 + 3 + 2 = 10.
+ */
+class TaskExtractor : public ExprVisitor {
+ public:
+  static Array<ExtractedTask> ExtractTask(IRModule mod, Target target) {
+    TaskExtractor extractor(mod, target);
+    // We go through each Relax function in the module.
+    for (const auto& kv : mod->functions) {
+      if (const auto* func = kv.second.as<FunctionNode>()) {
+        extractor(GetRef<Function>(func));
+      }
+    }
+    return std::move(extractor.tasks_);
+  }
+
+ private:
+  explicit TaskExtractor(IRModule mod, Target target)
+      : mod_(std::move(mod)), target_(std::move(target)) {
+    normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod");
+    ICHECK(normalize_mod_func_) << "Normalization function is not found.";
+  }
+
+  void VisitExpr_(const CallNode* call) final {
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+
+    // TODO(@tvm-team): When we differentiate calls to TIR functions and packed functions,
+    // this logic should be changed accordingly.
+    if (!call->op.same_as(call_tir_op)) {
+      // Since the Relax function is in A-normal form, the arguments of this call cannot
+      // themselves be Calls, and hence we do not need to recurse into this Call.
+      return;
+    }
+
+    // Do not extract external functions.
+    if (call->args[0].as<ExternFuncNode>()) {
+      return;
+    }
+
+    const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+    const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
+
+    auto it = func2task_.find(func);
+    if (it != func2task_.end()) {
+      it->second->weight += 1;
+      return;
+    }
+
+    IRModule tir_mod = (*normalize_mod_func_)(func);
+    ExtractedTask task(/*task_name=*/global_var->name_hint,  //
+                       /*mod=*/tir_mod,                      //
+                       /*target=*/target_,                   //
+                       /*dispatched=*/{tir_mod},             //
+                       /*weight=*/1);
+    tasks_.push_back(task);
+    func2task_.emplace(func, task);
+  }
+
+  IRModule mod_;
+  Target target_;
+  Array<ExtractedTask> tasks_;
+  std::unordered_map<tir::PrimFunc, ExtractedTask, StructuralHash, StructuralEqual> func2task_;
+  const runtime::PackedFunc* normalize_mod_func_;
+};
+
+TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask")
+    .set_body_typed([](IRModule mod, Target target) {
+      return TaskExtractor::ExtractTask(std::move(mod), std::move(target));
+    });
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
new file mode 100644
index 0000000000..d8d09ddd88
--- /dev/null
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -0,0 +1,581 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_vm.cc + * \brief A codegen to generate VM executable from a Relax IRModule. + */ + +#include "codegen_vm.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../../target/metadata_module.h" +#include "../../../target/source/codegen_source_base.h" + +namespace tvm { +namespace relax { +namespace relax_vm { + +using namespace relax; + +// Helper function to get the function name of the registered packed function implementation of +// relax operator. +FCallPacked GetPackedFuncName(const Call& call) { + auto op_map = Op::GetAttrMap("FCallPacked"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op]; + } + } + return {}; +} + +/*! + * \brief A class to generate VM executable for Relax functions. + */ +class CodeGenVM : public ExprFunctor { + public: + explicit CodeGenVM(ExecBuilderNode* builder) { builder_ = GetRef(builder); } + + protected: + size_t NewRegister() { return registers_num_++; } + Instruction::Arg VisitExpr_(const FunctionNode* func_node) { + Optional gsymbol = func_node->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift pass?"; + + Array param_names; + for (Var param : func_node->params) { + param_names.push_back(param->name_hint()); + } + + builder_->EmitFunction(gsymbol.value(), func_node->params.size(), param_names); + + for (Var param : func_node->params) { + Instruction::Arg reg = this->VisitExpr(param); + this->var_register_map_.insert({param, reg.data}); + } + Instruction::Arg ret = ExprFunctor::VisitExpr(func_node->body); + builder_->EmitRet(ret.data); + registers_num_ = 0; + return ret; + } + + Instruction::Arg VisitExpr_(const SeqExprNode* op) { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + ICHECK(binding->IsInstance()); + Expr value = Downcast(binding)->value; + Var var = Downcast(binding)->var; + Instruction::Arg reg = this->VisitExpr(value); + this->var_register_map_.insert({var, reg.data}); + } + } + + Instruction::Arg ret_reg = this->VisitExpr(op->body); + return ret_reg; + } + + Instruction::Arg VisitExpr_(const CallNode* call_node) { + if (call_node->op.as()) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + const Call& call = GetRef(call_node); + FCallPacked name = GetPackedFuncName(call); + if (!name.empty()) { + // If the operator has a registered packed function implementation, emit call to that packed + // function. 
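+        // For instance (the op name here is hypothetical): an op whose FCallPacked
+        // attribute is "vm.builtin.my_op" is emitted as a VM Call to that packed
+        // function, with the op's args (plus constant-encoded attrs, see
+        // AppendAttrsAsConstants) as arguments.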
+ return EmitPackedFuncCall(call, name); + } else if (call_node->op == alloc_storage_op_) { + return EmitAllocStorage(call); + } else if (call_node->op == alloc_tensor_op_) { + return EmitAllocTensor(call); + } else if (call_node->op == store_shape_op_ || call_node->op == load_shape_op_) { + return EmitShape(call); + } else if (call_node->op == call_tir_dyn_op_) { + return EmitTirDynOp(call); + } else if (call_node->op == make_closure_op_) { + return EmitAllocClosure(call); + } else if (call_node->op == invoke_closure_op_) { + return EmitInvokeClosure(call); + } else { + // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those + // ops are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op; + } + } + String name; + if (auto* extern_func = call_node->op.as()) { + name = extern_func->global_symbol; + } else if (auto* gvar = call_node->op.as()) { + // GlobalVar can be reference to a Relax function or a TIR primfunc + name = gvar->name_hint; + } else { + LOG(FATAL) << "CodeGenVM does not support calls to " << call_node->op->GetTypeKey(); + } + std::vector args; + // For extern function `vm.builtin.alloc_shape_heap` we must pass vm register as the first + // argument to find the device in which shape heap should be allocated. + if (name == "vm.builtin.alloc_shape_heap") { + args.push_back(Instruction::Arg(Instruction::kRegister, Instruction::kVMRegister)); + } + std::vector converted_args = ConvertArgs(GetRef(call_node)); + args.insert(args.end(), converted_args.begin(), converted_args.end()); + size_t dst_register = NewRegister(); + builder_->EmitCall(name, args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg VisitExpr_(const IfNode* op) { + const If& ife = GetRef(op); + // Get the executable from exec_builder + ObjectPtr exec_ = builder_->Get(); + + // Visit the condition expression + Instruction::Arg cond_reg = this->VisitExpr(ife->cond); + // Record the offset of If instruction + size_t if_offset = exec_->instr_offset.size(); + + builder_->EmitIf(cond_reg.value(), 3); + size_t num_instr = exec_->instr_offset.size(); + Instruction::Arg true_reg = this->VisitExpr(ife->true_branch); + // Reserve a register for return + size_t merge_register = NewRegister(); + // Copy the output from true branch to merge register + builder_->EmitCall("vm.builtin.copy", {true_reg}, merge_register); + + // Record the offset of Goto instruction + size_t goto_offset = exec_->instr_offset.size(); + + builder_->EmitGoto(1); + + // Calculate the false offset of If + size_t false_offset = exec_->instr_offset.size() - num_instr + 1; + + Instruction::Arg false_reg = this->VisitExpr(ife->false_branch); + // Copy the output data of false branch to merge register + builder_->EmitCall("vm.builtin.copy", {false_reg}, merge_register); + + // Update the offsets of the If instruction emitted above + // Jump to the behind of the next goto instruction + exec_->SetInstructionData(if_offset, 2, static_cast(false_offset)); + // Update the pc_offset of Goto instruction + // Jump over the false branch + size_t pc_offset = exec_->instr_offset.size() - goto_offset; + exec_->SetInstructionData(goto_offset, 1, static_cast(pc_offset)); + return Instruction::Arg(Instruction::kRegister, merge_register); + } + + Instruction::Arg VisitExpr_(const VarNode* op) { + auto it = this->var_register_map_.find(GetRef(op)); + if (it != this->var_register_map_.end()) { + return 
Instruction::Arg(Instruction::kRegister, it->second); + } else { + return Instruction::Arg(Instruction::kRegister, NewRegister()); + } + } + + Instruction::Arg VisitExpr_(const ConstantNode* op) { + TVMRetValue constant_data; + constant_data = op->data; + Index index = this->builder_->EmitConstant(constant_data); + + size_t dst_register = NewRegister(); + std::vector args; + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + builder_->EmitCall("vm.builtin.copy", args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg VisitExpr_(const ShapeExprNode* op) { + ShapeExpr sh = GetRef(op); + ICHECK(IsConstantShape(sh)) << "should only use constant shape after shape lowering: " + << sh->values; + std::vector shape; + for (PrimExpr e : sh->values) { + shape.push_back(Downcast(e)->value); + } + auto shape_tuple = ShapeTuple(shape); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + return Instruction::Arg(Instruction::kConstIdx, index); + } + + Instruction::Arg VisitExpr_(const TupleNode* op) { + Tuple tuple = GetRef(op); + std::vector args; + for (auto arg : tuple->fields) { + args.push_back(this->VisitExpr(arg)); + } + size_t dst_register = NewRegister(); + builder_->EmitCall("runtime.Tuple", args, dst_register); + + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg VisitExpr_(const TupleGetItemNode* op) { + TupleGetItem expr = GetRef(op); + std::vector args = {this->VisitExpr(expr->tuple)}; + + std::vector tuple_index = {expr->index}; + auto shape_tuple = ShapeTuple(tuple_index); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.runtime.TupleGetItem", args, dst_register); + + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg EmitAllocStorage(const Call& call_node) { + // Handle args of the call + std::vector args; + args.push_back(Instruction::Arg(Instruction::kVMRegister)); + for (Expr arg : call_node->args) { + args.push_back(ConvertArg(arg)); + } + + // Handle attrs of the call + auto alloc_attrs = call_node->attrs.as(); + ICHECK(alloc_attrs != nullptr) << "must be VMAllocStorageAttrs"; + Index runtime_device_index = alloc_attrs->runtime_device_index; + args.push_back(Instruction::Arg(Instruction::kImmediate, runtime_device_index)); + DataType dtype = alloc_attrs->dtype; + TVMRetValue data_type; + data_type = dtype; + Index index = this->builder_->EmitConstant(data_type); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.alloc_storage", args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg EmitAllocTensor(const Call& call_node) { + ICHECK_EQ(call_node->args.size(), 2); + std::vector args; + args.reserve(4); + // Handle `self` + args.push_back(ConvertArg(call_node->args[0])); + // Handle `offset` + auto alloc_attrs = call_node->attrs.as(); + ICHECK(alloc_attrs != nullptr) << "must be VMAllocTensorAttrs"; + int offset = alloc_attrs->offset; + args.push_back(Instruction::Arg(Instruction::kImmediate, offset)); + // Handle `shape` + args.push_back(ConvertArg(call_node->args[1])); + // Handle `dtype` + DataType dtype = 
alloc_attrs->dtype; + TVMRetValue data_type; + data_type = dtype; + Index index = this->builder_->EmitConstant(data_type); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg EmitShape(const Call& call_node) { + // Handle args of the call + std::vector args; + for (Expr arg : call_node->args) { + args.push_back(ConvertArg(arg)); + } + + // Handle attrs of the call + auto shape_attrs = call_node->attrs.as(); + ICHECK(shape_attrs != nullptr) << "must be ShapeHeapAttrs"; + std::vector indices_vec; + for (Integer ind : shape_attrs->indices) { + indices_vec.push_back(ind.IntValue()); + } + ShapeTuple indices = ShapeTuple(indices_vec); + TVMRetValue indices_const; + indices_const = indices; + Index index = builder_->EmitConstant(indices_const); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + + size_t dst_register = NewRegister(); + if (call_node->op == store_shape_op_) { + builder_->EmitCall("vm.builtin.store_shape", args, dst_register); + } else if (call_node->op == load_shape_op_) { + builder_->EmitCall("vm.builtin.load_shape", args, dst_register); + } + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg EmitTirDynOp(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + auto gv = Downcast(call_node->args[0]); + auto tir_args = Downcast(call_node->args[1]); + auto func_name = gv->name_hint; + + TVMRetValue func_name_constant; + func_name_constant = func_name; + auto func_name_index = builder_->EmitConstant(func_name_constant); + + std::vector args; + args.push_back(Instruction::Arg(Instruction::kVMRegister)); + args.push_back(Instruction::Arg(Instruction::kConstIdx, func_name_index)); + for (Expr arg : tir_args->fields) { + args.push_back(ConvertArg(arg)); + } + + size_t dst_register = NewRegister(); + + builder_->EmitCall("vm.call_tir_dyn", args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + template + Instruction::Arg EmitConstantFromValue(T value) { + TVMRetValue tvm_value; + tvm_value = value; + Index index = builder_->EmitConstant(tvm_value); + return Instruction::Arg(Instruction::kConstIdx, index); + } + + // Emit the `call_node` attributes as constants and append these constants to `args` vector. 
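+  // For example, from the cases handled below: a relax.unique call gains four trailing
+  // constant args (sorted, return_inverse, return_counts, dim); for relax.print the
+  // constant format string is inserted first; for relax.assert_op it is inserted right
+  // after the condition argument.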
+ void AppendAttrsAsConstants(const Call& call_node, std::vector& args) { + auto attrs = call_node->attrs; + if (!attrs.defined()) return; + + if (call_node->op == unique_op_) { + auto unique_attrs = call_node->attrs.as(); + args.push_back(EmitConstantFromValue(unique_attrs->sorted)); + args.push_back(EmitConstantFromValue(unique_attrs->return_inverse)); + args.push_back(EmitConstantFromValue(unique_attrs->return_counts)); + args.push_back(EmitConstantFromValue(unique_attrs->dim)); + return; + } + if (call_node->op == print_op_) { + auto print_attrs = call_node->attrs.as(); + // format string is the first argument + args.insert(args.begin(), EmitConstantFromValue(print_attrs->format)); + return; + } + if (call_node->op == assert_op_) { + auto assert_attrs = call_node->attrs.as(); + // format string comes before the format args + args.insert(args.begin() + 1, EmitConstantFromValue(assert_attrs->format)); + return; + } + LOG(FATAL) << "Support for attributes of Op " << call_node->op + << " has not been implemented yet."; + return; + } + + // Emits call to packed function `name` with arguments copied over from `call_node` args and + // attributes. + Instruction::Arg EmitPackedFuncCall(const Call& call_node, const FCallPacked& name) { + std::vector args; + args = ConvertArgs(call_node); + AppendAttrsAsConstants(call_node, args); + size_t dst_register = NewRegister(); + builder_->EmitCall(name, args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg EmitAllocClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + auto gv = Downcast(call_node->args[0]); + auto closure_args = Downcast(call_node->args[1]); + auto func_name = gv->name_hint; + + TVMRetValue func_name_constant; + func_name_constant = func_name; + auto func_name_index = builder_->EmitConstant(func_name_constant); + + std::vector args; + args.push_back(Instruction::Arg(Instruction::kConstIdx, func_name_index)); + for (Expr arg : closure_args->fields) { + args.push_back(ConvertArg(arg)); + } + + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.alloc_closure", args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + Instruction::Arg EmitInvokeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance()); + ICHECK(call_node->args[1]->IsInstance()); + + std::vector args; + // VM is utilized to help get the Function in builtin packedfunc + args.push_back(Instruction::Arg(Instruction::kVMRegister)); + + auto lv = Downcast(call_node->args[0]); + auto it = this->var_register_map_.find(lv); + if (it != this->var_register_map_.end()) { + args.push_back(Instruction::Arg(Instruction::kRegister, it->second)); + } else { + args.push_back(Instruction::Arg(Instruction::kRegister, registers_num_)); + } + + // args for the invoke_closure + auto invoke_closure_args = Downcast(call_node->args[1]); + for (Expr arg : invoke_closure_args->fields) { + args.push_back(ConvertArg(arg)); + } + + size_t dst_register = NewRegister(); + builder_->EmitCall("vm.builtin.invoke_closure", args, dst_register); + return Instruction::Arg(Instruction::kRegister, dst_register); + } + + bool IsConstantShape(ShapeExpr shape) const { + for (PrimExpr e : shape->values) { + if (!e->IsInstance()) { + return false; + } + } + return true; + } + + Instruction::Arg ConvertArg(Expr arg) { + if (arg->IsInstance()) { + 
Var var = Downcast(arg); + auto reg = this->var_register_map_.find(Downcast(arg)); + ICHECK(reg != this->var_register_map_.end()) << var->name_hint() << "(" << var << ")" + << " not in the register map."; + return Instruction::Arg(Instruction::kRegister, reg->second); + } else if (arg->IsInstance()) { + ShapeExpr sh = Downcast(arg); + ICHECK(IsConstantShape(sh)) << "should only use constant shape after shape lowering: " + << sh->values; + std::vector shape; + for (PrimExpr e : sh->values) { + shape.push_back(Downcast(e)->value); + } + auto shape_tuple = ShapeTuple(shape); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + return Instruction::Arg(Instruction::kConstIdx, index); + } else if (arg->IsInstance()) { + TVMRetValue constant_data; + constant_data = Downcast(arg)->data; + Index index = builder_->EmitConstant(constant_data); + return Instruction::Arg(Instruction::kConstIdx, index); + } else { + LOG(FATAL) << "CodeGenVM does not support this argument type:\n" << arg->GetTypeKey(); + } + return Instruction::Arg(); + } + + std::vector ConvertArgs(const Call& call) { + std::vector ret; + for (size_t i = 0; i < call->args.size(); ++i) { + ret.push_back(ConvertArg(call->args[i])); + } + return ret; + } + + /*! \brief A counter for naming local functions. */ + size_t local_func_counter_ = 0; + /*! \brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! \brief Total number of virtual registers allocated. */ + size_t registers_num_ = 0; + /*! \brief Map from var to register number. */ + std::unordered_map var_register_map_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.builtin.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.builtin.alloc_tensor"); + const Op& store_shape_op_ = Op::Get("relax.vm.builtin.store_shape"); + const Op& load_shape_op_ = Op::Get("relax.vm.builtin.load_shape"); + const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); + const Op& unique_op_ = Op::Get("relax.unique"); + const Op& print_op_ = Op::Get("relax.print"); + const Op& assert_op_ = Op::Get("relax.assert_op"); + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +void VMCodeGen::CodeGen(IRModule rx_mod) { + builder_ = relax::ExecBuilderNode::Create(); + CodeGenVM codegen(builder_.operator->()); + for (auto& p : rx_mod->functions) { + codegen.VisitExpr(p.second); + } +} + +ObjectPtr VMCodeGen::GetExec() { return builder_->Get(); } + +/*! + * \brief Create the Relax VM executable from an IRModule of Relax function(s) and, possibly, a + * kernel library. + * \param mod The IRModule containing Relax function(s). + * \param lib The kernel library. + * \return The constructed Relax VM executable. + */ +Module CodeGen(IRModule mod, Optional lib, Array ext_libs, Target target, + Map params) { + VMCodeGen codegen; + codegen.CodeGen(mod); + ObjectPtr executable = codegen.GetExec(); + if (!lib.defined()) { + lib = codegen::CSourceModuleCreate(";", "", Array{}); + } + std::unordered_map conv_params; + for (const auto& kv : params) { + conv_params[kv.first] = kv.second; + } + Module combined_lib = codegen::CreateMetadataModule( + conv_params, lib.value(), ext_libs, target, + + // TODO(@sunggg): Currently, CRT uses relay-specific executor for uTVM support. + // Before jumping into details, only support cpp runtime for now. 
+ relay::Runtime::Create("cpp"), + relay::Executor::Create("graph"), // TODO(@sunggg): pass arbitrarily executor. CPP runtime + // won't use this anyways. + relay::backend::ExecutorCodegenMetadata()); + executable->Import(combined_lib); + return Module(executable); +} + +TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(CodeGen); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/codegen_vm.h b/src/relax/backend/vm/codegen_vm.h new file mode 100644 index 0000000000..5791f86888 --- /dev/null +++ b/src/relax/backend/vm/codegen_vm.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_vm.h + * \brief A codegen to generate VM executable from an IRModule with relax functions. + */ + +#ifndef TVM_RELAX_BACKEND_VM_CODEGEN_VM_H_ +#define TVM_RELAX_BACKEND_VM_CODEGEN_VM_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { +namespace relax_vm { + +using tvm::Target; +using namespace tvm::runtime::relax_vm; +using namespace tvm::runtime; + +class VMCodeGen : public Object { + public: + /*! + * \brief Compile the functions in a Module. + * \param rx_mod Input IRModule that constains relax functions. + */ + void CodeGen(IRModule rx_mod); + /*! + * \brief Get the compiled executable. + * \return The compiled executable. + */ + ObjectPtr GetExec(); + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.VMCodeGen"; + + protected: + /*! \brief Internal executable builder. */ + relax::ExecBuilder builder_; +}; + +} // namespace relax_vm +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_BACKEND_VM_CODEGEN_VM_H_ diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc new file mode 100644 index 0000000000..c5d48d4f14 --- /dev/null +++ b/src/relax/backend/vm/exec_builder.cc @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/exec_builder.cc + */ +#include + +#include + +namespace tvm { +namespace relax { + +using namespace vm; + +TVM_REGISTER_NODE_TYPE(ExecBuilderNode); + +ExecBuilder ExecBuilderNode::Create() { + ExecBuilder ret(make_object()); + ret->exec = make_object(); + return ret; +} + +vm::Index ExecBuilderNode::EmitConstant(TVMRetValue obj) { + vm::Index idx = exec->constants.size(); + exec->constants.push_back(obj); + return vm::Instruction::Arg(vm::Instruction::kConstIdx, idx).data; +} + +void ExecBuilderNode::EmitFunction(std::string func_name, int64_t num_inputs, + Array param_names) { + const auto& m = exec->global_map; + ICHECK(m.find(func_name) == m.end()); + VMFunction vmfunc; + vmfunc.name = func_name; + vmfunc.start_instr = exec->instr_offset.size(); + vmfunc.num_args = num_inputs; + std::vector names; + for (size_t i = 0; i < param_names.size(); ++i) { + names.push_back(param_names[i]); + } + vmfunc.param_names = names; + exec->global_map[func_name] = exec->global_funcs.size(); + exec->global_funcs.push_back(vmfunc); +} + +void ExecBuilderNode::EmitCall(std::string func, std::vector args, RegName dst) { + // store function + if (exec->func2idx.find(func) == exec->func2idx.end()) { + exec->func2idx[func] = exec->func_names.size(); + exec->func_names.push_back(func); + } + Index func_idx = exec->func2idx[func]; + // store instruction + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::Call)); + exec->instr_data.push_back(dst); + exec->instr_data.push_back(func_idx); + exec->instr_data.push_back(args.size()); + // store arguments + std::transform(args.cbegin(), args.cend(), std::back_inserter(exec->instr_data), + [](Instruction::Arg arg) { return arg.data; }); +} + +void ExecBuilderNode::EmitRet(RegName result) { + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::Ret)); + exec->instr_data.push_back(result); +} + +void ExecBuilderNode::EmitGoto(Index pc_offset) { + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::Goto)); + exec->instr_data.push_back(pc_offset); +} + +void ExecBuilderNode::EmitIf(vm::RegName cond, vm::Index false_offset) { + exec->instr_offset.push_back(exec->instr_data.size()); + exec->instr_data.push_back(static_cast(Opcode::If)); + exec->instr_data.push_back(cond); + exec->instr_data.push_back(false_offset); +} + +void ExecBuilderNode::CheckExecutable() { + for (auto it = exec->global_funcs.cbegin(); it != exec->global_funcs.cend(); ++it) { + Index num_inputs = it->num_args; + std::unordered_set dst_registers; + std::unordered_set arg_registers; + size_t start_instr = it->start_instr; + size_t end_instr = exec->instr_offset.size(); + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = exec->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + for (int i = 0; i < instr.num_args; ++i) { + if (instr.args[i].kind() == Instruction::kRegister && + instr.args[i].value() == Instruction::kVMRegister) { + continue; + } + if (instr.args[i].kind() == Instruction::kRegister && + instr.args[i].value() >= num_inputs && + dst_registers.find(instr.args[i].value()) == dst_registers.end()) { + LOG(FATAL) << "register r(" << instr.args[i].value() << ") in VM function \"" + << it->name << "\" is used as input while the number of inputs is only " + 
<< num_inputs << ".\n"; + } + arg_registers.emplace(instr.args[i].value()); + } + if (instr.dst != Instruction::kVoidArg) { + dst_registers.emplace(instr.dst); + } + break; + } + case Opcode::Ret: { + arg_registers.emplace(instr.result); + for (int i = 0; i < num_inputs; i++) { + if (arg_registers.find(i) == arg_registers.end()) { + LOG(WARNING) << "register r(" << i << ") in VM function \"" << it->name + << "\" is unused as input.\n"; + } + } + break; + } + case Opcode::Goto: { + ICHECK_NE(instr.pc_offset, 0); + break; + } + case Opcode::If: { + ICHECK_GT(instr.false_offset, 1); + arg_registers.emplace(instr.cond); + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } +} + +ObjectPtr ExecBuilderNode::Get() { + this->CheckExecutable(); + this->Formalize(); + return this->exec; +} + +void ExecBuilderNode::Formalize() { + // a pass to formalize user-specified register indexes in the order of use + // and decide the number of registers to allocate for each VMFunction in the Executable + for (auto it = this->exec->global_funcs.begin(); it != this->exec->global_funcs.end(); ++it) { + Index num_inputs = it->num_args; + RegName register_idx = num_inputs; + std::unordered_map register_map; + size_t start_instr = it->start_instr; + size_t end_instr = this->exec->instr_offset.size(); + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->exec->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + for (int i = 0; i < instr.num_args; ++i) { + if (instr.args[i].kind() == Instruction::kRegister && + register_map.find(instr.args[i].value()) != register_map.end()) { + this->exec->instr_data[this->exec->instr_offset[idx] + 4 + i] = + register_map[instr.args[i].value()]; + } + } + if (instr.dst != Instruction::kVoidArg && instr.dst >= num_inputs && + register_map.find(instr.dst) == register_map.end()) { + this->exec->instr_data[this->exec->instr_offset[idx] + 1] = register_idx; + register_map[instr.dst] = register_idx++; + } + break; + } + case Opcode::Ret: { + if (register_map.find(instr.result) != register_map.end()) { + this->exec->instr_data[this->exec->instr_offset[idx] + 1] = register_map[instr.result]; + } + break; + } + case Opcode::Goto: { + break; + } + case Opcode::If: { + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + it->register_file_size = register_idx; + } +} + +TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitConstant").set_body([](TVMArgs args, TVMRetValue* ret) { + ExecBuilder builder = args[0]; + TVMRetValue rt; + rt = args[1]; + *ret = builder->EmitConstant(rt); +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderFunction") + .set_body_method(&ExecBuilderNode::EmitFunction); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") + .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { + std::vector args_; + for (size_t i = 0; i < args.size(); ++i) { + args_.push_back(static_cast(args[i]->value)); + } + Instruction::Arg dst_(dst); + CHECK_EQ(dst_.kind(), Instruction::ArgKind::kRegister); + builder->EmitCall(name, args_, dst_.value()); + }); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") + .set_body_method(&ExecBuilderNode::EmitRet); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto") + .set_body_method(&ExecBuilderNode::EmitGoto); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") + 
.set_body_method(&ExecBuilderNode::EmitIf); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg(Instruction::kRegister, value).data; +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg(Instruction::kImmediate, value).data; +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg(Instruction::kConstIdx, value).data; +}); + +TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { + ObjectPtr p_exec = builder->Get(); + return runtime::Module(p_exec); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/vm_memory_lower.cc b/src/relax/backend/vm/vm_memory_lower.cc new file mode 100644 index 0000000000..dfd79143b3 --- /dev/null +++ b/src/relax/backend/vm/vm_memory_lower.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_memory_lower.cc + * \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics. + */ +#include +#include +#include +#include +#include + +#include "../../../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +// ================== +// MemLowerMutator +// Lower the relax.builtin.alloc_tensor op to VM builtin functions. +// Example: +// x = relax.builtin.alloc_tensor((m, n), relax.attrs.AllocTensorAttrs) +// --> +// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), +// relax.attrs.VMAllocStorageAttrs) +// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n), +// relax.attrs.VMAllocTensorAttrs) + +class VMMemLowerMutator : public ExprMutator { + Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const { + // Question: what if the dtype of tensor_type is unknown? 
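+    // Worked example of the arithmetic below (a sketch): for a static shape (m, n) and
+    // dtype float32, bits * lanes = 32, and (32 + 7) / 8 = 4 under integer division, so
+    // the emitted storage size is m * n * 4 bytes.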
+ // Symbolic/static shape case + if (auto* shape_expr = shape.as()) { + PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes()); + PrimExpr add = num + 7; + PrimExpr ret = 1; + for (PrimExpr dim : shape_expr->values) { + ret = ret * dim; + } + ret = ret * (add / PrimExpr(8)); + return ShapeExpr({ret}); + } + // Fully dynamic shape case + // will need to dedup with ComputeStorageInRelay when we upstream + Expr prod = relay::Prod(shape, Array(nullptr), false, false); + Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes()); + Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7)); + Expr div = relay::MakeConstantScalar(DataType::Int(64), 8); + Expr ret = relay::Multiply(prod, relay::Divide(add, div)); + return ret; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call) override { + // post-order mutation + Expr expr = VisitExprPostOrder_(call); + call = expr.as(); + + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + static const Op& vm_alloc_storage_op = Op::Get("relax.vm.builtin.alloc_storage"); + static const Op& vm_alloc_tensor_op = Op::Get("relax.vm.builtin.alloc_tensor"); + + // TODO(@yuchen): memory planning + if (call->op == alloc_tensor_op) { + ShapeExpr output_shape = Downcast(call->args[0]); + auto alloc_attrs = call->attrs.as(); + ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs"; + DataType dtype = alloc_attrs->dtype; + Expr storage_size = ComputeStorageSize(output_shape, dtype); + auto storage_attr = make_object(); + storage_attr->dtype = dtype; + storage_attr->runtime_device_index = alloc_attrs->runtime_device_index; + + Var storage = + builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage"); + auto tensor_attr = make_object(); + tensor_attr->offset = 0; + tensor_attr->dtype = dtype; + Expr shape = call->args[0]; + Var tensor = + builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor"); + return std::move(tensor); + } + + return GetRef(call); + } +}; + +Expr VMMemLower(const Expr& e) { return VMMemLowerMutator().VisitExpr(e); } + +namespace transform { + +Pass VMMemoryLower() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(VMMemLower(f)); }; + return CreateFunctionPass(pass_func, 0, "VMMemoryLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMMemoryLower").set_body_typed(VMMemoryLower); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc new file mode 100644 index 0000000000..35550cba00 --- /dev/null +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_shape_lower.cc + * \brief Lower the shape expressions in relax to VM shape heap manipulations and generate related + * TIR functions to do shape calculations. + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class VMShapeLowerMutator : public ExprMutator { + public: + static DataType ShapeDType() { return DataType::Int(64); } + + explicit VMShapeLowerMutator(IRModule mod) : ExprMutator(mod) {} + + IRModule Lower() { + for (auto& p : builder_->GetContextIRModule()->functions) { + Expr func = p.second; + if (func->IsInstance()) { + // prepare mapping and heap var + expr2slot_ = PrepareExpr2Slot(Downcast(func)); + heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); + DynTensorType heap_type(1, ShapeDType()); + shape_heap_ = Var("shape_heap", ShapeExpr({heap_size_}), heap_type); + + // mutate + Function updated_func = Downcast(VisitExpr(func)); + builder_->UpdateFunction(p.first, updated_func); + } + } + return builder_->GetContextIRModule(); + } + + void VisitBinding_(const MatchShapeNode* binding) override { + Expr value = ExprMutator::VisitExpr(binding->value); + + // TODO(@yuchen): match_shape overloaded semantic: value is ShapeType + Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {value}), "sh"); + StoreShape(shape, binding->pattern); + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const ShapeExprNode* node) override { + if (IsConstantShape(GetRef(node))) { + return ExprMutator::VisitExpr_(node); + } + tir::PrimFunc func = CalculateShape(GetRef(node)); + + GlobalVar shape_func_var = builder_->AddFunction(func, "shape_func"); + builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); + + // construct shape + Array indices; + for (PrimExpr e : node->values) { + indices.push_back(expr2slot_.at(e)); + } + static const Op& load_shape_op = Op::Get("relax.vm.builtin.load_shape"); + auto load_shape_attr = make_object(); + load_shape_attr->indices = indices; + + return builder_->Emit(Call(load_shape_op, {shape_heap_}, Attrs(load_shape_attr)), "sh"); + } + + Expr VisitExpr_(const FunctionNode* node) override { + if (heap_size_->value > 0) { + builder_->BeginBindingBlock(); + builder_->Emit(VarBinding( + shape_heap_, Call(ExternFunc("vm.builtin.alloc_shape_heap"), {ShapeExpr({heap_size_})}))); + + for (Var param : node->params) { + if (param->shape_.operator bool() && param->shape_.value().as()) { + if (auto* param_type = param->checked_type_.as()) { + if (param_type->ndim != 0) { + Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {param}), "sh"); + StoreShape(shape, Downcast(param->shape_.value())->values); + } + } + } + } + } + Expr new_body = this->VisitExpr(node->body); + + Array blocks; + + if (const SeqExprNode* seq = new_body.as()) { + if (heap_size_->value > 0) { + blocks.push_back(builder_->EndBlock()); + } + blocks.insert(blocks.end(), seq->blocks.begin(), seq->blocks.end()); + new_body = seq->body; + } + + // FIXME(@yuchen): Implement vm.builtin.free_shape_heap. 
+ // builder_->Emit(Call(ExternFunc("vm.builtin.free_shape_heap"), {shape_heap_}), "gv"); + new_body = builder_->Normalize(SeqExpr(blocks, new_body)); + + Type ret_type = this->VisitType(node->ret_type); + // should not be necessary to do anything with it + Expr ret_shape = node->ret_shape; + + // Because this pass is the last stage of build, ndim info is no longer needed for tensors. + // The ret_type is weakened to unknown-dimensional DynTensorType. + // TODO(@yuchen): change all tensor types in the function to unknown ndim + if (const DynTensorTypeNode* temp = ret_type.as()) { + ret_type = DynTensorType::CreateUnknownNDim(temp->dtype, Span()); + } + + return Function(node->params, new_body, ret_type, ret_shape, node->attrs); + } + + tir::PrimFunc CalculateShape(ShapeExpr s) { + // TODO(ziheng): avoid generating shape func for known value + tir::Var heap("heap", DataType::Handle()); + Array buffer_shape{heap_size_}; + tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H"); + Map buffer_map; + buffer_map.Set(heap, buffer); + + Array seq; + for (PrimExpr e : s->values) { + Map var_mapping = BuildVarMapping(e, buffer); + PrimExpr value = tir::Substitute(e, var_mapping); + // cast value to shape heap dtype + if (value.dtype() != ShapeDType()) value = tir::Cast(ShapeDType(), value); + Integer idx = expr2slot_.at(e); + seq.push_back(tir::BufferStore(buffer, value, {idx})); + } + tir::Stmt body = tir::SeqStmt(seq); + Array params{heap}; + Type ret_type = VoidType(); + + return tir::PrimFunc(params, body, ret_type, buffer_map); + } + + Map BuildVarMapping(PrimExpr expr, tir::Buffer buffer) { + Map ret; + auto func = [&](const ObjectRef& e) { + if (e->IsInstance()) { + PrimExpr prim_e = Downcast(e); + tir::BufferLoad load(buffer, {expr2slot_.at(prim_e)}); + ret.Set(Downcast(e), load); + } + }; + tir::PostOrderVisit(expr, func); + return ret; + } + + Map PrepareExpr2Slot(Function expr) const { + int cnt = 0; + bool is_dyn_shape = false; + Map ret; + auto func = [&](const Expr& e) { + if (e->IsInstance()) { + ShapeExpr shape = Downcast(e); + for (auto prim_e : shape->values) { + if (!prim_e->IsInstance()) { + is_dyn_shape = true; + } + if (ret.count(prim_e) == 0) { + ret.Set(prim_e, cnt++); + } + } + } + }; + PostOrderVisit(expr, func); + + // Avoid allocating shape heap and do shape computation for static-shape program + if (!is_dyn_shape) { + ret.clear(); + } + return ret; + } + + /*! \brief Store symbolic shape into indices of the VM shape heap. 
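+   *
+   * For example (hypothetical pattern): storing a matched shape (n, m) whose
+   * slots in expr2slot_ are {n: 0, m: 1} emits a single
+   * relax.vm.builtin.store_shape(shape, shape_heap) call with indices = [0, 1],
+   * which writes the runtime values of n and m into heap[0] and heap[1].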
*/ + void StoreShape(Expr shape, Array pattern) { + static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape"); + auto store_shape_attr = make_object(); + + Array indices; + for (size_t i = 0; i < pattern.size(); ++i) { + Integer idx = expr2slot_.at(pattern[i]); + indices.push_back(idx); + } + store_shape_attr->indices = indices; + builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv"); + } + + bool IsConstantShape(ShapeExpr shape) const { + for (PrimExpr e : shape->values) { + if (!e->IsInstance()) { + return false; + } + } + return true; + } + + private: + // function-wise members + IntImm heap_size_; + Var shape_heap_; + Map expr2slot_; +}; + +namespace transform { + +Pass VMShapeLower() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return VMShapeLowerMutator(mod).Lower(); }; + return CreateModulePass(pass_func, 0, "VMShapeLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed(VMShapeLower); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc new file mode 100644 index 0000000000..a6d6b27d81 --- /dev/null +++ b/src/relax/ir/binding_rewrite.cc @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/binding_rewrite.cc + * \brief Implementation of binding rewriters. + */ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(DataflowBlockRewriteNode); +DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) { + auto n = make_object(); + n->dfb_ = dfb; + n->root_fn_ = root_fn; + n->original_fn_ptr_ = root_fn.get(); + auto p = FunctionUseDef(root_fn); + n->to_users_ = std::move(p.first); + n->fn_outputs_ = std::move(p.second); + n->name_table_ = NameTable(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); + + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") + .set_body_typed([](DataflowBlock dfb, Function root_fn) { + return DataflowBlockRewrite(dfb, root_fn); + }); + +void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { + class ReplaceAllUsePass : public ExprMutator { + Var old_var, new_var; + const DataflowBlockNode* const to_catch; + + public: + const DataflowBlockNode* caught = nullptr; + + ReplaceAllUsePass(Var old_var, Var new_var, const DataflowBlockNode* to_catch) + : old_var(old_var), new_var(new_var), to_catch(to_catch) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const VarNode* op) override { + return (op == old_var.get()) ? 
new_var : GetRef(op); + } + + Expr VisitExpr_(const DataflowVarNode* op) override { + return (op == old_var.get()) ? new_var : GetRef(op); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + BindingBlock res = ExprMutator::VisitBindingBlock_(op); + if (op == to_catch) caught = static_cast(res.get()); + return res; + } + }; + + ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; + ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; + + // replace uses inside the DataflowBlock. + ReplaceAllUsePass replacer(old_var, new_var, dfb_.get()); + root_fn_ = Downcast(replacer.VisitExpr_(root_fn_.get())); + dfb_ = GetRef(replacer.caught); + + // update udchain + // old_var -> old_var users | changed to {} + // new_var -> {?} | changed to old_var users + for (Var user : to_users_[old_var]) { + auto new_var_uses = to_users_[new_var]; + if (new_var_uses.end() == std::find(new_var_uses.begin(), new_var_uses.end(), user)) { + new_var_uses.push_back(user); + } + } + + to_users_.Set(old_var, {}); + + auto it_old_output = std::find(fn_outputs_.begin(), fn_outputs_.end(), old_var); + if (it_old_output != fn_outputs_.end()) { + fn_outputs_.Set(std::distance(fn_outputs_.begin(), it_old_output), new_var); + } +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") + .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { + rwt->ReplaceAllUses(old_var, new_var); + }); + +class UpdateDFB : public ExprMutator { + private: + DataflowBlock old_dfb, new_dfb; + + public: + UpdateDFB(DataflowBlock old_dfb, DataflowBlock new_dfb) + : old_dfb(std::move(old_dfb)), new_dfb(std::move(new_dfb)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override { + return old_dfb.get() == op ? new_dfb : old_dfb; + } +}; + +void DataflowBlockRewriteNode::Add(Binding binding) { + auto p = [binding] { + if (auto vb = binding.as()) { + return std::make_pair(vb->var, vb->value); + } else if (auto ms = binding.as()) { + return std::make_pair(ms->var, ms->value); + } + LOG(FATAL) << "Unsupported binding type"; + return std::make_pair(Var{}, Expr{}); + }(); + + Var var = p.first; + Expr val = p.second; + + ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; + + // Add this VarBinding statement after the definition of uses. 
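+  // E.g. (hypothetical block): when adding w = relax.add(x, y) and y's
+  // definition is the binding at index 3, the new binding is inserted at
+  // index 4, right after the last definition it depends on, so the block
+  // stays in def-before-use order.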
+ std::set used_vars = [val] { + class UsedVars : public ExprVisitor { + public: + std::set used_vars; + void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } + void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); } + } uvar{}; + uvar.VisitExpr(val); + return std::move(uvar.used_vars); + }(); + + size_t line_last_req_def = 0; + for (size_t i = 0; i < dfb_.value()->bindings.size(); ++i) { + auto line = dfb_.value()->bindings[i]; + if (auto varbind = line.as()) { + if (used_vars.find(varbind->var.get()) != used_vars.cend()) line_last_req_def = i; + } else if (auto mshape = line.as()) { + if (used_vars.find(mshape->var.get()) != used_vars.cend()) line_last_req_def = i; + } + } + + auto old_dfb = dfb_.value(); + + dfb_ = [old_dfb, binding, line_last_req_def, this] { + auto new_dfb = dfb_.value(); + new_dfb.CopyOnWrite()->bindings.insert(dfb_.value()->bindings.begin() + 1 + line_last_req_def, + binding); + return new_dfb; + }(); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast(updater.VisitExpr_(root_fn_.get())); + + for (const VarNode* v : used_vars) to_users_.Get(GetRef(v)).value().push_back(var); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") + .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") + .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { + if (name.get()) { + rwt->Add(name.value(), expr, is_dfvar); + } else { + rwt->Add(expr, is_dfvar); + } + }); + +class RemoveUnusedVars : public ExprMutator { + public: + std::set unused_vars; + Optional caught_rewrite = NullOpt; + + RemoveUnusedVars(Map> users, Array fn_outputs) + : unused_vars([&] { + std::vector unused; + + // iterative dataflow algorithm. + size_t prev_size; + do { + prev_size = unused.size(); + + for (const auto& kv : users) { + // var -> [users...] + // var is unused iff + // user -> empty + // var is not output var + if (kv.second.empty() && // kv.first is not used by fn outputs. + fn_outputs.end() == std::find(fn_outputs.begin(), fn_outputs.end(), kv.first)) { + unused.push_back(kv.first); + } + } + + for (size_t i = prev_size; i < unused.size(); ++i) { + users.erase(unused[i]); + // remove def site. + for (auto kv : users) { // remove use site. + auto it = std::find(kv.second.begin(), kv.second.end(), unused[i]); + if (it != kv.second.end()) { + kv.second.erase(it); + users.Set(kv.first, std::move(kv.second)); + } + } + } + } while (prev_size != unused.size()); // changed? => continue. 
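+          // E.g. (hypothetical chain): given b = f(a); c = g(b) where c is
+          // neither used nor an output, pass 1 collects c, which empties b's
+          // use list, and pass 2 collects b; the loop stops once a full pass
+          // discovers no new unused vars.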
+ + return std::set(unused.begin(), unused.end()); + }()) {} + + RemoveUnusedVars(std::pair>, Array> users_and_outputs) + : RemoveUnusedVars(std::move(users_and_outputs.first), std::move(users_and_outputs.second)) {} + RemoveUnusedVars(Function fn) : RemoveUnusedVars(FunctionUseDef(fn)) {} + RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + auto prev_dfb = GetRef(block); + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + if (const auto* node = binding.as()) { + if (!unused_vars.count(node->var)) VisitBinding_(node); + } else if (const auto* node = binding.as()) { + if (!unused_vars.count(node->var)) VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + auto new_dfb = builder_->EndBlock(); + if (caught_rewrite == prev_dfb) caught_rewrite = Downcast(new_dfb); + return std::move(new_dfb); + } +}; + +void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { + // first need to check if this var is used. + if (0 == to_users_.count(unused)) { // no def. + if (allow_undef) return; + LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; + } + + ICHECK(to_users_[unused].empty()) + << unused << " is used by " << to_users_[unused].size() << " vars"; + + auto old_dfb = dfb_.value(); + + RemoveUnusedVars remover({unused}); + dfb_ = Downcast(remover.VisitBindingBlock_(old_dfb.get())); + + auto updater = UpdateDFB(old_dfb, dfb_.value()); + root_fn_ = Downcast(updater.VisitExpr_(root_fn_.get())); + + to_users_.erase(unused); // update use-def chain. +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") + .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { + rwt->RemoveUnused(unused, allow_undef); + }); + +void DataflowBlockRewriteNode::RemoveAllUnused() { + RemoveUnusedVars remover(to_users_, fn_outputs_); + remover.caught_rewrite = dfb_.value(); + + // this could also clean unused variables in other DataflowBlock. + root_fn_ = Downcast(remover.VisitExpr_(root_fn_.get())); + + // DataflowBlock could be None. + dfb_ = remover.caught_rewrite.value(); + + // clean up use-def chain. + for (const auto& unused : remover.unused_vars) to_users_.erase(unused); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") + .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); + +Function RemoveAllUnused(Function fn) { + RemoveUnusedVars remover(fn); + return Downcast(remover.VisitExpr_(fn.get())); +} + +TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); + +IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { + BlockBuilder builder = BlockBuilder::Create(irmod); + + for (auto& p : irmod->functions) { + if (original_fn_ptr_ == p.second.get()) { + builder->UpdateFunction(p.first, root_fn_.value()); + break; + } + } + + return builder->GetContextIRModule(); +} + +TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") + .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { + return rwt->MutateIRModule(irmod); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc new file mode 100644 index 0000000000..5ceca53a11 --- /dev/null +++ b/src/relax/ir/block_builder.cc @@ -0,0 +1,821 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/block_builder.cc + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// ================================ +// BlockBuilderNode::ExprNormalizer + +// TODO(@altanh): more test cases to cover different visits +class BlockBuilderNode::ExprNormalizer : public ExprFunctor { + public: + ExprNormalizer(BlockBuilderNode* builder) : builder_(builder) {} + +#define RELAX_EXPR_NORMALIZER_LEAF(OP) \ + Expr VisitExpr_(const OP* op) final { return GetRef(op); } + + RELAX_EXPR_NORMALIZER_LEAF(VarNode); + RELAX_EXPR_NORMALIZER_LEAF(DataflowVarNode); + RELAX_EXPR_NORMALIZER_LEAF(RuntimeDepShapeNode); + RELAX_EXPR_NORMALIZER_LEAF(ExternFuncNode); + RELAX_EXPR_NORMALIZER_LEAF(GlobalVarNode); + RELAX_EXPR_NORMALIZER_LEAF(OpNode); + RELAX_EXPR_NORMALIZER_LEAF(ShapeExprNode); + + // TODO(@altanh): CopyOnWrite + + Expr VisitExpr(const Expr& expr) { + // TODO(relax-team): generalize prim_func support + if (expr->IsInstance()) { + return expr; + } + Optional post = expr_memo_.Get(expr); + if (post) { + ICHECK(post.as()) << "memoized expressions should map to variables"; + return post.value(); + } + return ExprFunctor::VisitExpr(expr); + } + + Expr VisitExpr_(const TupleNode* op) final { + bool unchanged = true; + Array new_fields; + for (const Expr& field : op->fields) { + Expr new_field = this->Bind(field); + new_fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + Tuple tuple = unchanged ? GetRef(op) : Tuple(new_fields); + + // only do shape/type inference if the Tuple does not have shape/type + if (tuple->shape_ && tuple->checked_type_.defined()) { + return tuple; + } + + // Tuple's shape can be null, when a tuple consists of all DynTensorType, it has a shape + if (!tuple->shape_) { + UpdateShape(tuple, GetTupleShape(tuple)); + } + + // Tuple's checked_type must not be null + if (!tuple->checked_type_.defined()) { + Array tuple_type; + for (Expr field : tuple->fields) { + ICHECK(field->checked_type_.defined()) + << "The checked_type_ of the field " << field << " of Tuple has not propagated."; + tuple_type.push_back(field->checked_type_); + } + UpdateType(tuple, TupleType(tuple_type)); + } + return tuple; + } + + Expr VisitExpr_(const FunctionNode* op) final { + Expr new_body = this->VisitWithNewScope(op->body); + Function func; + if (new_body.same_as(op->body)) { + func = GetRef(op); + } else { + func = Function(op->params, new_body, op->ret_type, op->ret_shape, op->attrs); + } + + // NOTE: the shape_ of Function is left as null for now, to be consitent with + // Skip the deduction of function type of a function + // as the function type needs to be annotated in certain cases(mutual function call) + // TODO(tvm-team) deduce function's type in construction time. 
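+    // E.g. (hypothetical): for mutually recursive functions f and g, f's type
+    // cannot be deduced from its body alone, since the body calls g before g
+    // has been processed; the annotation supplies it instead.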
+ return func; + } + + Expr VisitExpr_(const CallNode* op) final { + Expr new_op = this->VisitExpr(op->op); + bool unchanged = new_op.same_as(op->op); + + Array new_args; + if (!op->args.empty()) { + for (const Expr& arg : op->args) { + Expr new_arg = this->Bind(arg); + new_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + } + + Call call; + if (unchanged) { + call = GetRef(op); + } else { + call = Call(new_op, new_args, op->attrs, op->type_args); + } + + // only do shape/type inference if the Call does not have shape/type + if (call->shape_ && call->checked_type_.defined()) { + return call; + } + + if (!call->shape_) { + // shape inference + auto inferred_shape = + InferShape(call, this->builder_->diag_ctx_, this->builder_->context_mod_); + if (inferred_shape) { + UpdateShape(call, inferred_shape.value()); + } + } + + if (!call->checked_type_.defined()) { + // type inference + auto inferred_type = InferType(call, this->builder_->diag_ctx_, this->builder_->context_mod_); + UpdateType(call, inferred_type); + } + return call; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool unchanged = true; + Array new_blocks; + for (const BindingBlock& block : op->blocks) { + // TODO(@altanh): we could merge sequential non-dataflow BindingBlocks here + BindingBlock new_block = this->VisitBindingBlock(block); + new_blocks.push_back(new_block); + unchanged &= new_block.same_as(block); + } + + builder_->BeginBindingBlock(); + Expr new_body = this->VisitExpr(op->body); + unchanged &= new_body.same_as(op->body); + BindingBlock prologue = builder_->EndBlock(); + + // TODO(@altanh, @yuchen): normalize nested SeqExprs and BindingBlocks + + if (!prologue->bindings.empty()) { + new_blocks.push_back(prologue); + unchanged = false; + } + + SeqExpr seq_expr; + if (unchanged) { + seq_expr = GetRef(op); + } + seq_expr = SeqExpr(new_blocks, new_body); + + // only do shape/type inference if the SeqExpr does not have shape/type + if (seq_expr->shape_ && seq_expr->checked_type_.defined()) { + return seq_expr; + } + + if (!seq_expr->shape_ && seq_expr->body->shape_) { + UpdateShape(seq_expr, seq_expr->body->shape_); + } + + if (!seq_expr->checked_type_.defined() && seq_expr->body->checked_type_.defined()) { + UpdateType(seq_expr, seq_expr->body->checked_type_); + } + return seq_expr; + } + + Expr VisitExpr_(const ConstantNode* op) final { + Constant constant = GetRef(op); + + // only do shape/type inference if the Constant does not have shape/type + if (constant->shape_ && constant->checked_type_.defined()) { + return constant; + } + + auto shape_tuple = constant->data.Shape(); + if (!constant->shape_) { + Array values; + for (size_t dim = 0; dim < shape_tuple.size(); dim++) { + values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); + } + UpdateShape(constant, relax::ShapeExpr(values)); + } + + if (!constant->checked_type_.defined()) { + DataType dtype = constant->data.DataType(); + Type type = relax::DynTensorType(shape_tuple.size(), dtype); + UpdateType(constant, type); + } + return constant; + } + + Expr VisitExpr_(const IfNode* op) final { + Expr new_cond = this->VisitExpr(op->cond); + Expr new_true = this->VisitWithNewScope(op->true_branch); + Expr new_false = this->VisitWithNewScope(op->false_branch); + if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) && + new_false.same_as(op->false_branch)) { + return GetRef(op); + } + // TODO(relax-team): fix type/shape deduction for if node. 
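+    // A complete deduction would have to reconcile the checked_type_ and
+    // shape_ of both branches, e.g. weakening to a common supertype when the
+    // branches disagree (sketch of intent only).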
+ return If(new_cond, new_true, new_false); + } + + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr new_tuple = this->VisitExpr(op->tuple); + TupleGetItem node; + if (new_tuple.same_as(op->tuple)) { + node = GetRef(op); + } + node = TupleGetItem(new_tuple, op->index); + + // only do shape/type inference if the TupleGetItem does not have shape/type + if (node->shape_ && node->checked_type_.defined()) { + return node; + } + + if (!node->checked_type_.defined()) { + const TupleTypeNode* tuple_type = node->tuple->checked_type_.as(); + ICHECK(tuple_type) << "The checked_type_ of Tuple must be TupleTypeNode."; + UpdateType(node, tuple_type->fields[node->index]); + } + + if (!node->shape_ && node->tuple->shape_) { + // TODO(@prakalp, @yuchen): assign the shape_ to RuntimeDepShape when we cannot obtain the + // field + if (const TupleNode* shape = node->tuple->shape_.as()) { + UpdateShape(node, shape->fields[node->index]); + } + } + + return node; + } + + Binding VisitBinding(const Binding& binding) { + if (binding.as()) { + return this->VisitVarBinding(Downcast(binding)); + } else { + ICHECK(binding.as()) << "expected VarBinding or MatchShape, got " << binding; + return this->VisitMatchShape(Downcast(binding)); + } + } + + VarBinding VisitVarBinding(const VarBinding& binding) { + Expr new_value = this->VisitExpr(binding->value); + if (new_value.same_as(binding->value) || new_value.same_as(binding->var)) { + // if new_value = binding->var, then we have found an ANF binding site, so just return it + return binding; + } + return VarBinding(binding->var, new_value); + } + + MatchShape VisitMatchShape(const MatchShape& binding) { + Expr new_value = this->VisitExpr(binding->value); + if (new_value.same_as(binding->value)) { + return binding; + } + return MatchShape(new_value, binding->pattern, binding->var); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) { + if (block.as()) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); + } + + bool unchanged = true; + for (const Binding& binding : block->bindings) { + Binding new_binding = this->VisitBinding(binding); + unchanged &= new_binding.same_as(binding); + if (new_binding.as()) { + VarBinding var_binding = Downcast(new_binding); + if (builder_->CurrentBlockIsDataFlow() && !var_binding->var.as()) { + builder_->EmitOutput(var_binding); + } else { + builder_->Emit(var_binding); + } + } else { + ICHECK(new_binding.as()); + builder_->EmitMatchShape(Downcast(new_binding)); + } + } + BindingBlock new_block = builder_->EndBlock(); + unchanged &= new_block->bindings.size() == block->bindings.size(); + if (unchanged) { + return block; + } + return new_block; + } + + private: + /*! + * \brief Memoization map for expressions using Id for equality of variables. 
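+   *
+   * Vars are keyed by their Id (vid) rather than by object identity, so two
+   * Var handles that share a vid memoize to the same normalized expression;
+   * all other expressions fall back to object-identity keys in expr_memo_.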
+ */ + class ExprMemo { + public: + Optional Get(const Expr& expr) { + if (const VarNode* var = expr.as()) { + auto it = var_memo_.find(var->vid); + if (it != var_memo_.end()) { + return it->second; + } + } else { + auto it = expr_memo_.find(expr); + if (it != expr_memo_.end()) { + return it->second; + } + } + return NullOpt; + } + + void Set(const Expr& pre, const Expr& post) { + if (const VarNode* var = pre.as()) { + var_memo_[var->vid] = post; + } else { + expr_memo_[pre] = post; + } + } + + private: + std::unordered_map var_memo_; + std::unordered_map expr_memo_; + }; + + // Helper function to check if a ShapeExpr is constant shape or tuple of constant shape + bool IsConstantShapes(const Expr& shape) const { + if (const auto* shape_expr = shape.as()) { + return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), + [](const PrimExpr& e) { return e->IsInstance(); }); + } else if (const auto* shape_tuple = shape.as()) { + return std::all_of(shape_tuple->fields.begin(), shape_tuple->fields.end(), + [&](const Expr& e) { return IsConstantShapes(e); }); + } else { + return false; + } + } + + // Helper function to infer the shape of a Call. + Optional InferShape(const Call& call, DiagnosticContext diag_ctx, IRModule ctx_mod) { + if (call->op.as()) { + // call_packed: return RuntimeDepShape + return RuntimeDepShape(); + } else if (call->op.as()) { + // primitive op: look up FInferShape attribute + Op op = Downcast(call->op); + if (op_map_infer_shape_.count(op)) { + return op_map_infer_shape_[op](call, diag_ctx); + } + } else if (const auto* gv = call->op.as()) { + // global function: find the function's shape_ + auto it_func = ctx_mod->functions.find(GetRef(gv)); + + if (it_func != ctx_mod->functions.end()) { + if (const auto* func = (*it_func).second.as()) { + Expr func_shape = Downcast(func->body->shape_); + if (IsConstantShapes(func_shape)) { + // Case 1. Nested tuples of constant shapes + return func_shape; + } else { + // TODO(@yuchen): add deducer for other cases + return RuntimeDepShape(); + } + } + } + // TODO(@yuchen): add this check after normalization in parser + // else { + // LOG(FATAL) << "ValueError: Cannot find function " << gv->name_hint + // << " in the context IRModule."; + // } + } else if (const auto* var = call->op.as()) { + if (var->shape_) { + return Downcast(var->shape_.value()); + } + Optional val = builder_->LookupBinding(GetRef(var)); + if (const auto* func_node = val.value().as()) { + Function func = GetRef(func_node); + if (func->ret_type.as()) { + Expr func_shape = Downcast(func_node->body->shape_); + if (IsConstantShapes(func_shape)) { + return func_shape; + } else { + // TODO(@yuchen, @yongwww): add deducer for other cases + return RuntimeDepShape(); + } + } + } + } else { + LOG(FATAL) << "ValueError: Failed to do shape inference for " << call->op->GetTypeKey(); + } + + return NullOpt; + } + + // Helper function to infer the type of a Call. 
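+  // Dispatch mirrors InferShape above, e.g.: an extern (call_packed) callee
+  // reads the explicit type_args annotation; a primitive Op uses its
+  // registered FInferType rule; a GlobalVar callee reuses the callee
+  // function's annotated ret_type.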
+ Type InferType(const Call& call, DiagnosticContext diag_ctx, IRModule ctx_mod) { + // call_packed: return type_args of the call node + // TODO(@yuchen): add type annotation for ExternFuncNode + if (call->op.as()) { + if (call->type_args.defined()) { + if (call->type_args.size() == 0) { + return ObjectType(); + } else if (call->type_args.size() == 1) { + return call->type_args.front(); + } else { + return TupleType(call->type_args); + } + } + } else if (call->op.as()) { + // primitive op: look up FInferType attribute + Op op = Downcast(call->op); + if (op_map_infer_type_.count(op)) { + return op_map_infer_type_[op](call, diag_ctx); + } else { + LOG(FATAL) << "ValueError: Cannot find the FInferType attribute registered to op: " + << op->name; + } + } else if (const auto* gv = call->op.as()) { + // global function: find the function's checked_type_ + auto it_func = ctx_mod->functions.find(GetRef(gv)); + if (it_func != ctx_mod->functions.end()) { + if (const auto* func = (*it_func).second.as()) { + return func->ret_type; + } + // TODO(@yuchen): add this check after normalization in parser + // else { + // LOG(FATAL) << "ValueError: Cannot find function " << gv->name_hint + // << " in the context IRModule."; + // } + } + } else if (auto* var = call->op.as()) { + // TODO(@yongwww, yuchen): handle the infer with more specific cases + Optional val = builder_->LookupBinding(GetRef(var)); + if (const auto* func_node = val.value().as()) { + return func_node->ret_type; + } + if (auto* ft_node = var->checked_type_.as()) { + return ft_node->ret_type; + } + } else { + // TODO(@yuchen): call to local var/function support + LOG(FATAL) << "ValueError: Failed to do type inference for " << call->op->GetTypeKey(); + } + + return Type(); + } + + // Helper function to get the shape of a Tuple based on its fields + Optional GetTupleShape(const Tuple& tuple) { + Array tuple_shape; + for (Expr field : tuple->fields) { + if (field->shape_) { + tuple_shape.push_back(Downcast(field->shape_.value())); + } else { + break; + } + } + if (tuple_shape.size() == tuple->fields.size()) { + return Tuple(tuple_shape); + } + return NullOpt; + } + + static bool IsLeaf(const Expr& expr) { + // NB: tuples are treated as leaf nodes for ergonomics + // TODO(@altanh, @yuchen): remove TupleNode from leaf + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || + expr.as() || expr.as() || expr.as(); + } + + Expr VisitWithNewScope(const Expr& expr) { + builder_->BeginBindingBlock(); + Expr post = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + post = SeqExpr({prologue}, post); + } + return post; + } + + Expr Bind(const Expr& expr) { + Expr post = this->VisitExpr(expr); + if (!IsLeaf(post)) { + post = builder_->Emit(post); + expr_memo_.Set(expr, post); + } + return post; + } + + /*! \brief BlockBuilder used for emitting intermediate variables. */ + BlockBuilderNode* builder_; + + /*! \brief Memoization table for mapping expressions to their ANF variables. */ + ExprMemo expr_memo_; + + /*! \brief Operator to shape inference map. */ + tvm::OpAttrMap op_map_infer_shape_ = Op::GetAttrMap("FInferShape"); + + /*! \brief Operator to type inference map. 
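+   *
+   * E.g. the "FInferType" rule registered for "relax.add" (op name taken from
+   * the matcher code below) is what InferType dispatches to for primitive ops.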
*/ + tvm::OpAttrMap op_map_infer_type_ = Op::GetAttrMap("FInferType"); +}; + +// ================ +// BlockBuilderNode + +TVM_REGISTER_NODE_TYPE(BlockBuilderNode); + +BlockBuilderNode::BlockBuilderNode() { + name_table_ = std::make_unique(); + normalizer_ = std::make_unique(this); +} + +BlockBuilderNode::~BlockBuilderNode() { + if (!block_stack_.empty()) { + LOG(WARNING) << "BlockBuilder destroyed with remaining blocks!"; + } +} + +void BlockBuilderNode::BeginDataflowBlock() { this->block_stack_.push({{}, true}); } + +void BlockBuilderNode::BeginBindingBlock() { this->block_stack_.push({{}, false}); } + +BindingBlock BlockBuilderNode::EndBlock() { + BlockFrame* cur_frame = CurrentFrame(); + BindingBlock ret = cur_frame->is_dataflow ? DataflowBlock(cur_frame->bindings) + : BindingBlock(cur_frame->bindings); + block_stack_.pop(); + return ret; +} + +Var BlockBuilderNode::Emit(const Expr& expr, std::string name_hint) { + return Emit(expr, CurrentFrame()->is_dataflow, name_hint); +} + +Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_hint) { + BlockFrame* cur_frame = CurrentFrame(); + Expr normalized = Normalize(expr); + + if (name_hint.empty()) { + name_hint = is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + Var var = is_dataflow ? DataflowVar(vid, NullOpt, NullOpt) : Var(vid, NullOpt, NullOpt); + UpdateType(var, normalized->checked_type_); + UpdateShape(var, normalized->shape_); + + cur_frame->bindings.push_back(VarBinding(var, expr)); + binding_table_[var->vid] = expr; + + return var; +} + +Var BlockBuilderNode::Emit(const VarBinding& binding) { + BlockFrame* cur_frame = CurrentFrame(); + if (cur_frame->is_dataflow) { + ICHECK(binding->var.as()) + << "Emit can only be used for local bindings in a dataflow block, use EmitOutput for " + "output bindings instead"; + } + cur_frame->bindings.push_back(binding); + binding_table_[binding->var->vid] = binding->value; + return binding->var; +} + +Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array& pattern, + std::string name_hint) { + BlockFrame* cur_frame = CurrentFrame(); + + if (name_hint.empty()) { + name_hint = cur_frame->is_dataflow ? "lv" : "gv"; + } + Id vid = Id(name_table_->GetUniqueName(name_hint)); + Var var = + cur_frame->is_dataflow ? DataflowVar(vid, NullOpt, NullOpt) : Var(vid, NullOpt, NullOpt); + + if (value->checked_type().as()) { + UpdateType(var, ShapeType()); + } else if (const DynTensorTypeNode* tty = value->checked_type().as()) { + ShapeExpr shape = ShapeExpr(pattern); + UpdateShape(var, shape); + DataType dtype = tty->dtype; + UpdateType(var, DynTensorType(pattern.size(), dtype)); + } else { + this->diag_ctx_.EmitFatal( + Diagnostic::Error(value->span) + << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType."); + } + + MatchShape match_shape = MatchShape(value, pattern, var); + cur_frame->bindings.push_back(match_shape); + binding_table_[var->vid] = value; + return var; +} + +Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) { + BlockFrame* cur_frame = CurrentFrame(); + if (binding->var.defined()) { + ICHECK(cur_frame->is_dataflow || !binding->var.as()) + << "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint(); + binding_table_[binding->var->vid] = binding->value; + } + cur_frame->bindings.push_back(binding); + // TODO(@altanh, @yuchen): what value should we bind? 
Consider + // y = add(x, x) + // z = match_shape(y, (n, m)) + // We would like pass writers to match "z" with the "add" node but with extra shape info. + // Maybe this logic could be deferred to a DFPattern-style rewriter? + return binding->var; +} + +Var BlockBuilderNode::EmitOutput(const Expr& output, std::string name_hint) { + BlockFrame* cur_frame = CurrentFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + + return Emit(output, false, name_hint); +} + +Var BlockBuilderNode::EmitOutput(const VarBinding& binding) { + BlockFrame* cur_frame = CurrentFrame(); + + ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + ICHECK(!binding->var.as()) << "EmitOutput can only emit Var bindings."; + + cur_frame->bindings.push_back(binding); + binding_table_[binding->var->vid] = binding->value; + return binding->var; +} + +Optional BlockBuilderNode::LookupBinding(const Var& var) { + auto it = binding_table_.find(var->vid); + if (it == binding_table_.end()) return NullOpt; + return it->second; +} + +bool BlockBuilderNode::CanProveShapeEqual(const Expr& lhs, const Expr& rhs) { + if (lhs == rhs) { + return true; + } + const auto* lhs_shape = lhs.as(); + const auto* rhs_shape = rhs.as(); + if (lhs_shape && rhs_shape) { + size_t lhs_ndim = lhs_shape->values.size(); + size_t rhs_ndim = rhs_shape->values.size(); + if (lhs_ndim != rhs_ndim) { + return false; + } + arith::Analyzer analyzer; + for (size_t i = 0; i < lhs_ndim; ++i) { + PrimExpr lhs_dim = lhs_shape->values[i]; + PrimExpr rhs_dim = rhs_shape->values[i]; + if (!analyzer.CanProveEqual(lhs_dim, rhs_dim)) { + return false; + } + } + return true; + } + return false; +} + +// TODO(@altanh, @yuchen): need an internal Emit_ that doesn't call normalize +Expr BlockBuilderNode::Normalize(const Expr& expr) { + // TODO(@altanh): fast path + Expr normalized = normalizer_->VisitExpr(expr); + return normalized; +} + +BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() { + ICHECK(!block_stack_.empty()) << "no block is being built"; + return &block_stack_.top(); +} + +NameTable* BlockBuilderNode::name_table() { return name_table_.get(); } + +GlobalVar BlockBuilderNode::AddFunction(const BaseFunc& func, const String& func_name_hint) { + auto it = func_map_.find(func); + if (it == func_map_.end()) { + context_mod_.CopyOnWrite(); + String func_name = name_table_->GetUniqueName(func_name_hint); + GlobalVar gvar = GlobalVar(func_name); + if (const tir::PrimFuncNode* prim_func = func.as()) { + tir::PrimFunc fn = GetRef(prim_func); + fn = WithAttr(std::move(fn), "global_symbol", func_name); + context_mod_->Add(gvar, fn); + } else { + context_mod_->Add(gvar, func); + } + func_map_.emplace(func, gvar); + return gvar; + } else { + return it->second; + } +} + +void BlockBuilderNode::UpdateFunction(const GlobalVar& gv, BaseFunc updated_function) { + context_mod_.CopyOnWrite(); + context_mod_->Update(gv, updated_function); + func_map_[updated_function] = gv; +} + +IRModule BlockBuilderNode::GetContextIRModule() const { return context_mod_; } + +BlockBuilder BlockBuilder::Create(Optional mod) { + ObjectPtr n = make_object(); + if (mod) { + n->context_mod_ = mod.value(); + } + BlockBuilder block_builder(n); + for (const auto& kv : n->context_mod_->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + block_builder->func_map_.emplace(func, gv); + } + return block_builder; +} + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional 
mod) { + return BlockBuilder::Create(mod); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") + .set_body_method(&BlockBuilderNode::BeginDataflowBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") + .set_body_method(&BlockBuilderNode::BeginBindingBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock") + .set_body_method(&BlockBuilderNode::EndBlock); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_method(&BlockBuilderNode::Normalize); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder builder, Expr expr) { + return builder->Emit(expr); +}); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitVarBinding") + .set_body_typed([](BlockBuilder builder, VarBinding binding) { + return builder->Emit(binding); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShape") + .set_body_typed([](BlockBuilder builder, Expr value, Array pattern) { + return builder->EmitMatchShape(value, pattern); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShapeBinding") + .set_body_typed([](BlockBuilder builder, MatchShape binding) { + return builder->EmitMatchShape(binding); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") + .set_body_typed([](BlockBuilder builder, const Expr& output) { + return builder->EmitOutput(output); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutputVarBinding") + .set_body_typed([](BlockBuilder builder, VarBinding binding) { + return builder->EmitOutput(binding); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") + .set_body_typed([](BlockBuilder builder, String name_hint) { + return builder->name_table()->GetUniqueName(name_hint); + }); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") + .set_body_method(&BlockBuilderNode::AddFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") + .set_body_method(&BlockBuilderNode::UpdateFunction); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") + .set_body_method(&BlockBuilderNode::GetContextIRModule); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCanProveShapeEqual") + .set_body_method(&BlockBuilderNode::CanProveShapeEqual); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") + .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); + +TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") + .set_body_method(&BlockBuilderNode::LookupBinding); +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc new file mode 100644 index 0000000000..680e62e7f9 --- /dev/null +++ b/src/relax/ir/dataflow_matcher.cc @@ -0,0 +1,754 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relax/ir/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relax. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dataflow_matcher_impl.h" + +namespace tvm { +namespace relax { + +using tvm::arith::Analyzer; + +// Pattern Matcher +bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + memo_.clear(); + matched_nodes_.clear(); + return VisitDFPattern(pattern, expr); +} + +static Expr TryGetValOfVar(const Expr& expr, const runtime::Map& var2val) { + if (var2val.empty()) return expr; + + // if not match, try to match value of var if expr is a var. + if (const VarNode* var = expr.as()) { + auto may = var2val.Get(GetRef(var)); + if (may.defined()) return may.value(); + } + + return expr; +} + +void DFPatternMatcher::ClearMap(size_t watermark) { + for (size_t i = watermark; i < matched_nodes_.size(); ++i) { + memo_.erase(matched_nodes_[i]); + } + matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); +} + +bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (memoize_ && memo_.count(pattern)) { + ICHECK_EQ(memo_[pattern].size(), 1); + return expr.same_as(memo_[pattern][0]); + } else { + size_t watermark = matched_nodes_.size(); + bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern].push_back(expr); + matched_nodes_.push_back(pattern); + } else { + ClearMap(watermark); + } + return out; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return !VisitDFPattern(op->reject, expr); +} + +bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + switch (rhs.type_code()) { + case kDLInt: + if (auto* val = lhs.as()) { + return val->value == rhs.operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = lhs.as()) { + return val->value == rhs.operator double(); + } + break; + case kTVMStr: + if (auto* val = lhs.as()) { + return val->value == rhs.operator std::string(); + } else if (auto* val = lhs.as()) { + return val->data == rhs.operator std::string(); + } + break; + case kTVMDataType: + if (auto* val = lhs.as()) { + return rhs.operator std::string() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator std::string() == val->data; + } else { + ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs; + } + break; + case kTVMObjectHandle: + if (rhs.IsObjectRef()) { + if (auto* val = lhs.as()) { + return rhs.operator String() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator String() == val->data; + } + } else { + // Compare the objects for structural equality + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + return true; + } + } + 
break; + default: + ICHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = VisitDFPattern(attr_pattern->pattern, expr); + if (!matches) return matches; + VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr); + auto attributes = attr_pattern->attrs.as()->dict; + if (const auto* op_node = expr.as()) { + Op op = GetRef(op_node); + for (auto kv : attributes) { + auto attr_name = kv.first; + auto attr_value = kv.second; + if (Op::HasAttrMap(attr_name)) { + auto op_map = Op::GetAttrMap(attr_name); + if (op_map.count(op)) { + matches &= MatchRetValue(attr_value, op_map[op]); + } else { + matches = false; + } + } else { + matches = false; + } + } + } else if (auto* op = expr.as()) { + matches = true; + // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this + // and replace the whole thing with a Visitor-based approach + ReflectionVTable* reflection = ReflectionVTable::Global(); + auto attrs_node = const_cast(op->attrs.get()); + // attrs may be undefined on non-op calls so we check first + std::vector attr_names; + if (attrs_node) { + attr_names = reflection->ListAttrNames(attrs_node); + } + for (auto kv : attributes) { + std::string attr = kv.first; + if (matches && std::find(attr_names.begin(), attr_names.end(), attr) != attr_names.end()) { + matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, attr)); + } else { + matches = false; + break; + } + } + } else if (auto* op = expr.as()) { + matches = true; + for (auto kv : attributes) { + if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) { + matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]); + } else { + matches = false; + break; + } + } + } else { + matches = false; + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + // utilities + auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { + if (op) { + if (auto* expr_pattern = op->op.as()) { + return expr_pattern->expr.as(); + } + } + return nullptr; + }; + auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) { + if (const auto* op_node = get_op_node(op)) { + if (op_node->name == op_type) { + return true; + } + } + return false; + }; + auto is_expr_op = [](const Expr& expr, std::string op_type) { + if (const auto* call_node = expr.as()) { + if (const auto* op_node = call_node->op.as()) { + if (op_node->name == op_type) { + return true; + } + } + } + return false; + }; + + // logic + auto watermark = matched_nodes_.size(); + if (const auto* call_node = expr.as()) { + auto matches_op = VisitDFPattern(op->op, call_node->op); + if (matches_op) { + auto watermark2 = matched_nodes_.size(); + + auto match_args = [this, &watermark2](const Array& pattern_args, auto expr_begin, + auto expr_end) { + bool matches = true; + auto pattern_it = pattern_args.begin(); + auto expr_it = expr_begin; + if (pattern_args.defined()) { + while (matches && pattern_it != pattern_args.end()) + matches &= VisitDFPattern(*(pattern_it++), *(expr_it++)); + } + if (!matches) ClearMap(watermark2); + return matches; + }; + + const size_t n_arg_pattern = op->args.size(); + const size_t n_arg_expr = call_node->args.size(); + // if allow variable args, #pattern must >= 
#expr. + if (op->varg_default_wildcard && n_arg_expr < n_arg_pattern) return false; + // if variable args are not allowed, #pattern must == #expr. + if (!op->varg_default_wildcard && n_arg_expr != n_arg_pattern) return false; + + // Standard case + if (match_args(op->args, call_node->args.begin(), call_node->args.end())) return true; + + // Commutative Matching + if (const OpNode* op_node = get_op_node(op)) { + if ((op_node->name == "relax.add") || (op_node->name == "relax.multiply")) { + if (match_args(op->args, call_node->args.rbegin(), call_node->args.rend())) { + return true; + } + } + } + } else { + ClearMap(watermark); + // associate divide/multiply + if (is_pattern_op(op, "relax.divide")) { + if (const auto* arg_node = op->args[0].as()) { + if (is_pattern_op(arg_node, "relax.multiply") && is_expr_op(expr, "relax.multiply") && + (is_expr_op(call_node->args[0], "relax.divide") || + is_expr_op(call_node->args[1], "relax.divide"))) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}); + out = VisitDFPattern(mul, expr); + if (out) { + return true; + } else { + ClearMap(watermark); + } + } + return out; + } + } + } + if (is_pattern_op(op, "relax.multiply")) { + // associate multiply/divide + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + if (auto* arg_node = op->args[arg_id].as()) { + if (is_pattern_op(arg_node, "relax.divide") && is_expr_op(expr, "relax.divide") && + (is_expr_op(call_node->args[0], "relax.multiply") || + is_expr_op(call_node->args[1], "relax.multiply"))) { + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}); + return VisitDFPattern(div, expr); + } + } + } + } + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + return StructuralEqual()(op->expr, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = false; + if (const auto* func = expr.as()) { + matches = true; + if (op->params.defined()) { + size_t i = 0; + if (op->params.size() == func->params.size()) { + while (matches && i < op->params.size()) { + matches &= VisitDFPattern(op->params[i], func->params[i]); + ++i; + } + } else { + matches = false; + } + } + if (matches) { + matches &= VisitDFPattern(op->body, func->body); + } + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const auto* tuple_get_item_node = expr.as()) { + return (op->index == -1 || op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + bool matches = false; + if (const auto* tuple_node = expr.as()) { + matches = true; + if (op->fields.size() == tuple_node->fields.size()) { + size_t i = 0; + while (matches && i < op->fields.size()) { + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); + ++i; + } + } else { + matches = false; + } + } + return matches; +} + +bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const 
tvm::Array patterns, + const tvm::Array fields, + std::vector& match_cache, + std::vector& matched) { + if (idx >= patterns.size()) return true; + constexpr int8_t kUnknown = -1; + auto this_pattern = patterns[idx]; + for (size_t i = 0; i < fields.size(); ++i) { + if (matched[i]) continue; + const size_t table_idx = idx * fields.size() + i; + match_cache[table_idx] = + kUnknown ? VisitDFPattern(this_pattern, fields[i]) : match_cache[table_idx]; + if (match_cache[table_idx]) { + // continue to match the rest; + matched[i] = true; + if (TryUnorderedMatch(idx + 1, patterns, fields, match_cache, matched)) return true; + matched[i] = false; + } + } + + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + + if (const auto* tuple_node = expr.as()) { + if (op->fields.size() == tuple_node->fields.size()) { + constexpr int8_t kUnknown = -1; + ICHECK_LE(op->fields.size(), std::numeric_limits::max()) << "Too many fields!"; + // dynamic programming. + std::vector match_cache(op->fields.size() * op->fields.size(), kUnknown); + std::vector field_match_bitmap(op->fields.size(), false); + return TryUnorderedMatch(0, op->fields, tuple_node->fields, match_cache, field_match_bitmap); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + auto expr_type = expr.as()->checked_type(); + return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); +} + +static bool ShapeEqual(Analyzer* analyzer, const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) + if (!tir::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false; + return true; +} + +bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { + // no need to jump, as var.shape == value.shape + if (const ShapeExprNode* shape_expr = expr->shape().as()) + return ShapeEqual(&analyzer_, op->shape, shape_expr->values) && + VisitDFPattern(op->pattern, expr); + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const ShapeExprNode* shape_expr = expr.as()) + return ShapeEqual(&analyzer_, op->fields, shape_expr->values); + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { + // no need to jump, as var.dtype == value.dtype + auto expr_type = expr.as()->checked_type(); + if (const DynTensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { + // We don't jump for var pattern, as there's no need to access its value to judge it. + if (const auto* var_node = expr.as()) { + // "" means any name. 
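+  // E.g. a VarPattern built with name_hint "" matches any Var, while one
+  // built with "x" matches only a Var whose name hint is exactly "x".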
+ return "" == op->name_hint() || op->name_hint() == var_node->name_hint(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { + auto expr = TryGetValOfVar(expr0, var2val_); + if (const auto* extern_fn = expr.as()) { + return "" == op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { + // constants can be binded to relax.Var as well. + auto expr = TryGetValOfVar(expr0, var2val_); + return expr.as() != nullptr; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataflowVarPatternNode* op, const Expr& expr) { + // DataflowVar is inherented from Var, so dispatch it to VarPattern. + return expr->IsInstance() && + VisitDFPattern_(static_cast(op), expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) { + // GlobalVarPattern is not inherited from Var, so we need to handle it separately. + if (const auto* var_node = expr.as()) + return "" == op->name_hint() || op->name_hint() == var_node->name_hint; + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { + return true; +} + +bool DFPatternMatcher::VisitDFPattern_(const RuntimeDepShapePatternNode* op, const Expr& expr) { + return expr->shape_->IsInstance() && VisitDFPattern(op->pattern, expr); +} + +bool MatchExpr(DFPattern pattern, Expr expr, Optional> var2val) { + if (var2val.defined()) // autojump is enabled with var2val. + return DFPatternMatcher(std::move(var2val.value())).Match(pattern, expr); + else + return DFPatternMatcher().Match(pattern, expr); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +struct PNode { + const DFPatternNode* ptr; + const VarNode* matched = nullptr; + std::vector&>> children; + std::vector&>> parents; +}; + +struct RNode { + const VarNode* ptr; + const DFPatternNode* matched = nullptr; + std::vector children; + std::vector parents; +}; + +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, + const std::map>& def2use, + const std::map>& use2def) { + if (nullptr != p->matched && p->matched == r->ptr) return true; // matched before. + if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return false; + + std::stack> undo_stack{}; + + const auto commit = [&undo_stack](PNode* p, RNode* r) { + // match with each other. + p->matched = r->ptr; + r->matched = p->ptr; + undo_stack.emplace(p, r); + }; + + const auto quit = [&undo_stack] { + while (!undo_stack.empty()) { + auto& top = undo_stack.top(); + top.first->matched = nullptr; + top.second->matched = nullptr; + undo_stack.pop(); + } + return false; + }; + + commit(p, r); + + // match parent patterns. + for (auto& pparent_pairs : p->parents) { + PNode* pparent = pparent_pairs.first; + const std::vector& constraints = pparent_pairs.second; + + bool any_cons_sat = false; + for (auto& rparent : r->parents) { + // skip if mismatch. + if (rparent->matched && rparent->matched != pparent->ptr) continue; + + const auto& uses = def2use.at(rparent->ptr); + // skip if `rparent` is not used by `r`. + if (uses.cend() == uses.find(r->ptr)) continue; + + // check edge constraints. 
+ bool cons_sat = true; + for (const auto& cons : constraints) { + if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) { + cons_sat = false; + break; + } + + if (-1 != cons.index) { + const auto& callees = use2def.at(r->ptr); + if (static_cast(cons.index) >= callees.size() || + rparent->ptr != callees[cons.index]) { + cons_sat = false; + break; + } + } + } + if (!cons_sat) continue; + any_cons_sat = true; + + // try all parent R nodes that are not matched yet. + // as long as ppattern can match one node. + if (!pparent->matched && try_match(pparent, rparent, m, def2use, use2def)) { + commit(pparent, rparent); + break; + } + } + if (!pparent->matched || !any_cons_sat) return quit(); + } + + // forward matching; + for (auto& pchild_pairs : p->children) { + PNode* pchild = pchild_pairs.first; + const std::vector& constraints = pchild_pairs.second; + bool any_cons_sat = false; + for (auto& rchild : r->children) { + if (rchild->matched && rchild->matched != pchild->ptr) continue; + + const auto& uses = def2use.at(r->ptr); + if (uses.cend() == uses.find(rchild->ptr)) continue; + + // check edge constraints. + bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (-1 != cons.index) { + const auto& callees = use2def.at(rchild->ptr); + if (static_cast(cons.index) >= callees.size() || r->ptr != callees[cons.index]) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass) continue; + any_cons_sat = true; + + if (!pchild->matched && try_match(pchild, rchild, m, def2use, use2def)) { + commit(pchild, rchild); + break; + } + } + if (!pchild->matched || !any_cons_sat) return quit(); + } + + return true; +} + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + def2use[op].insert(cur_user_); + caller2callees[cur_user_].push_back(op); + } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + +tvm::runtime::Map MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + Optional start_hint, bool must_include_hint) { + tvm::runtime::Map ret{}; + // TODO(@ganler): Handle non-may external use. + ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + ICHECK(!must_include_hint || start_hint.defined()) + << "must_include_hint is only supported with start_hint."; + + const auto var2val = AnalyzeVar2Value(dfb); + DFPatternMatcher matcher(var2val); + + // std::map> + MatcherUseDefAnalysis ud_analysis; + ud_analysis.VisitBindingBlock_(dfb.get()); + const auto& def2use = ud_analysis.def2use; + const auto& caller2callees = ud_analysis.caller2callees; + + // First construct a graph of PNode and RNode. 
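+  // PNodes mirror the def/use constraints declared in the PatternContext, and
+  // RNodes mirror the actual def/use edges between vars in the dataflow block;
+  // try_match then searches for an assignment of RNodes to PNodes that
+  // satisfies every edge constraint, backtracking on failure.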
+ std::unordered_map var2node; + var2node.reserve(dfb->bindings.size()); + + for (const auto& du : def2use) { + const VarNode* cur_var = du.first; + const std::set& uses = du.second; + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->constraints.size()); + + for (const auto& def2use_pattern : ctx->constraints) { + const DFPatternNode* def_pattern = def2use_pattern.first.get(); + const std::map>& uses = def2use_pattern.second; + PNode& def_node = pattern2node[def_pattern]; + def_node.ptr = def_pattern; + def_node.children.reserve(uses.size()); + for (const auto& use : uses) { + const auto& cons = use.second; + const DFPatternNode* use_pattern = use.first.get(); + PNode& use_node = pattern2node[use_pattern]; + use_node.ptr = use_pattern; + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + if (start_hint.defined()) { + Var v = start_hint.value(); + auto rnode_ptr = var2node.find(v.get()); + for (auto& ppair : pattern2node) { + if (try_match(&ppair.second, &rnode_ptr->second, &matcher, def2use, caller2callees)) { + for (auto ppair : pattern2node) + ret.Set(GetRef(ppair.first), GetRef(ppair.second.matched)); + return ret; + } + } + + if (must_include_hint) return ret; + } + + PNode* pnode_start = &pattern2node.begin()->second; + + if (!pnode_start->matched) { + for (auto& rpair : var2node) { + if (start_hint.defined() && start_hint.value().get() == rpair.first) continue; + if (try_match(pnode_start, &rpair.second, &matcher, def2use, caller2callees)) { + for (auto ppair : pattern2node) + ret.Set(GetRef(ppair.first), GetRef(ppair.second.matched)); + + return ret; + } + } + } + + return ret; +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher_impl.h new file mode 100644 index 0000000000..13f2b36840 --- /dev/null +++ b/src/relax/ir/dataflow_matcher_impl.h @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relax/dataflow_matcher_impl.h + * \brief The auxiliary data structure for dataflow matcher. 
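+ * \note DFPatternMatcher matches a single pattern against a single expression;
+ *       the graph-level matcher in dataflow_matcher.cc composes it with a
+ *       def/use analysis to match entire dataflow blocks.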
+ */ +#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ +#define TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +class DFPatternMatcher : public DFPatternFunctor { + public: + using var2val_t = runtime::Map; + + explicit DFPatternMatcher() {} + explicit DFPatternMatcher(var2val_t var2val) : var2val_(std::move(var2val)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const OrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AndPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const NotPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + + bool VisitDFPattern_(const RuntimeDepShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataflowVarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool TryUnorderedMatch(size_t idx, const tvm::Array patterns, + const tvm::Array fields, std::vector& match_cache, + std::vector& matched); + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; + var2val_t var2val_; + std::vector matched_nodes_; + arith::Analyzer analyzer_; + bool memoize_ = true; +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc new file mode 100644 index 0000000000..69ef767b0f --- /dev/null +++ b/src/relax/ir/dataflow_pattern.cc @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_pattern.cc + * \brief The dataflow pattern language for Relax (inherited from Relay). + */ + +#include +#include + +#include +#include + +#define RELAX_PATTERN_PRINTER_DEF(NODE_TYPE, REPR_LAMBDA) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { \ + auto* node = static_cast(ref.get()); \ + REPR_LAMBDA(p, node); \ + }) + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ExternFuncPatternNode); +ExternFuncPattern::ExternFuncPattern(String global_symbol) { + ObjectPtr n = make_object(); + n->global_symbol_ = std::move(global_symbol); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { + return ExternFuncPattern(global_symbol); +}); +RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { + p->stream << "ExternFuncPattern(" << node->global_symbol() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(VarPatternNode); +VarPattern::VarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { + return VarPattern(name_hint); +}); +RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { + p->stream << "VarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(DataflowVarPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { + return DataflowVarPattern(name_hint); +}); +DataflowVarPattern::DataflowVarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +RELAX_PATTERN_PRINTER_DEF(DataflowVarPatternNode, [](auto p, auto node) { + p->stream << "DataflowVarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(GlobalVarPatternNode); +GlobalVarPattern::GlobalVarPattern(String name_hint) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { + return GlobalVarPattern(name_hint); +}); +RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { + p->stream << "GlobalVarPattern(" << node->name_hint() << ")"; +}); + +TVM_REGISTER_NODE_TYPE(RuntimeDepShapePatternNode); +RuntimeDepShapePattern::RuntimeDepShapePattern(DFPattern root) { + ObjectPtr n = make_object(); + n->pattern = std::move(root); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.RuntimeDepShapePattern").set_body_typed([](DFPattern root) { + return RuntimeDepShapePattern(std::move(root)); +}); +RELAX_PATTERN_PRINTER_DEF(RuntimeDepShapePatternNode, [](auto p, auto node) { + p->stream << "RuntimeDepShapePattern(" << node->pattern << " has runtime-dep shape)"; +}); + +TVM_REGISTER_NODE_TYPE(ExprPatternNode); +ExprPattern::ExprPattern(Expr expr) { + ObjectPtr n = make_object(); + n->expr = std::move(expr); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { 
return ExprPattern(e); }); +RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); + +TVM_REGISTER_NODE_TYPE(ConstantPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { + auto c = ConstantPattern(make_object()); + return c; +}); +RELAX_PATTERN_PRINTER_DEF(ConstantPatternNode, + [](auto p, auto node) { p->stream << "ConstantPattern()"; }); + +TVM_REGISTER_NODE_TYPE(CallPatternNode); +CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_wildcard) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->varg_default_wildcard = varg_default_wildcard; + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.CallPattern") + .set_body_typed([](DFPattern op, Array args, bool varg_default_wildcard) { + return CallPattern(op, args, varg_default_wildcard); + }); +RELAX_PATTERN_PRINTER_DEF(CallPatternNode, [](auto p, auto node) { + p->stream << node->op << "("; + for (size_t i = 0; i < node->args.size(); ++i) { + if (i != 0) p->stream << ", "; + p->stream << node->args[i]; + } + if (node->varg_default_wildcard) { + if (node->args.size() != 0) p->stream << ", "; + p->stream << "..."; + } + p->stream << ")"; +}); + +TVM_REGISTER_NODE_TYPE(PrimArrPatternNode); +PrimArrPattern::PrimArrPattern(Array arr) { + ObjectPtr n = make_object(); + n->fields = std::move(arr); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { + return PrimArrPattern(std::move(arr)); +}); +RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { + p->stream << "PrimArrPattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(FunctionPatternNode); +FunctionPattern::FunctionPattern(Array params, DFPattern body) { + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.FunctionPattern") + .set_body_typed([](Array params, DFPattern body) { + return FunctionPattern(params, body); + }); +RELAX_PATTERN_PRINTER_DEF(FunctionPatternNode, [](auto p, auto node) { + p->stream << "FunctionPattern(" << node->params << ", " << node->body << ")"; +}); + +TVM_REGISTER_NODE_TYPE(TuplePatternNode); +TuplePattern::TuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { + return TuplePattern(fields); +}); +RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { + p->stream << "TuplePattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(UnorderedTuplePatternNode); +UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") + .set_body_typed([](tvm::Array fields) { return UnorderedTuplePattern(fields); }); +RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { + p->stream << "UnorderedTuplePattern(" << node->fields << ")"; +}); + +TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); +TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern").set_body_typed([](DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); +}); 
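+// Note: a TupleGetItemPattern index of -1 matches a TupleGetItem at any index;
+// e.g. TupleGetItemPattern(pat, -1) accepts both `t[0]` and `t[1]` (see the matcher).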
+RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { + p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; +}); + +TVM_REGISTER_NODE_TYPE(AndPatternNode); +AndPattern::AndPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { + return AndPattern(left, right); +}); +RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { + p->stream << "AndPattern(" << node->left << " & " << node->right << ")"; +}); + +TVM_REGISTER_NODE_TYPE(OrPatternNode); +OrPattern::OrPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { + return OrPattern(left, right); +}); +RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { + p->stream << "OrPattern(" << node->left << " | " << node->right << ")"; +}); + +TVM_REGISTER_NODE_TYPE(NotPatternNode); +NotPattern::NotPattern(DFPattern reject) { + ObjectPtr n = make_object(); + n->reject = std::move(reject); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { + return NotPattern(reject); +}); +RELAX_PATTERN_PRINTER_DEF(NotPatternNode, + [](auto p, auto node) { p->stream << "!(" << node->reject << ")"; }); + +TVM_REGISTER_NODE_TYPE(WildcardPatternNode); +TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { + auto w = WildcardPattern(make_object()); + return w; +}); +RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); + +TVM_REGISTER_NODE_TYPE(TypePatternNode); +TypePattern::TypePattern(DFPattern pattern, Type type) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->type = std::move(type); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { + return TypePattern(pattern, type); +}); +RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { + p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; +}); + +TVM_REGISTER_NODE_TYPE(ShapePatternNode); +ShapePattern::ShapePattern(DFPattern pattern, Array shape) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->shape = std::move(shape); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.ShapePattern") + .set_body_typed([](DFPattern pattern, Array shape) { + return ShapePattern(pattern, shape); + }); +RELAX_PATTERN_PRINTER_DEF(ShapePatternNode, [](auto p, auto node) { + p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; +}); + +TVM_REGISTER_NODE_TYPE(DataTypePatternNode); +DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->dtype = std::move(dtype); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.DataTypePattern") + .set_body_typed([](DFPattern pattern, DataType dtype) { + return DataTypePattern(pattern, dtype); + }); +RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) { + p->stream << "DataTypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; +}); + +TVM_REGISTER_NODE_TYPE(AttrPatternNode); 
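+// Note: AttrPattern is typically constructed through the DFPattern::HasAttr
+// sugar defined later in this file.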
+AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->attrs = std::move(attrs); + data_ = std::move(n); +} +TVM_REGISTER_GLOBAL("relax.dpl.AttrPattern").set_body_typed([](DFPattern pattern, DictAttrs attrs) { + return AttrPattern(pattern, attrs); +}); +RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { + p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; +}); + +class DFPatternDuplicator : public DFPatternFunctor { + public: + DFPattern VisitDFPattern(const DFPattern& pattern) override { + return DFPatternFunctor::VisitDFPattern(pattern); + } + DFPattern VisitDFPattern_(const OrPatternNode* op) override { + return OrPattern(op->left, op->right); + } + DFPattern VisitDFPattern_(const AndPatternNode* op) override { + return AndPattern(op->left, op->right); + } + DFPattern VisitDFPattern_(const NotPatternNode* op) override { return NotPattern(op->reject); } + DFPattern VisitDFPattern_(const VarPatternNode* op) override { return VarPattern(op->name); } + DFPattern VisitDFPattern_(const ConstantPatternNode* op) override { + return ConstantPattern(make_object()); + } + DFPattern VisitDFPattern_(const WildcardPatternNode* op) override { + return WildcardPattern(make_object()); + } + DFPattern VisitDFPattern_(const ExprPatternNode* op) override { return ExprPattern(op->expr); } + DFPattern VisitDFPattern_(const GlobalVarPatternNode* op) override { + return GlobalVarPattern(op->name); + } + DFPattern VisitDFPattern_(const TuplePatternNode* op) override { + return TuplePattern(op->fields); + } + DFPattern VisitDFPattern_(const UnorderedTuplePatternNode* op) override { + return UnorderedTuplePattern(op->fields); + } + DFPattern VisitDFPattern_(const TupleGetItemPatternNode* op) override { + return TupleGetItemPattern(op->tuple, op->index); + } + DFPattern VisitDFPattern_(const CallPatternNode* op) override { + return CallPattern(op->op, op->args); + } + DFPattern VisitDFPattern_(const DataTypePatternNode* op) override { + return DataTypePattern(op->pattern, op->dtype); + } + DFPattern VisitDFPattern_(const FunctionPatternNode* op) override { + return FunctionPattern(op->params, op->body); + } + DFPattern VisitDFPattern_(const ShapePatternNode* op) override { + return ShapePattern(op->pattern, op->shape); + } + DFPattern VisitDFPattern_(const TypePatternNode* op) override { + return TypePattern(op->pattern, op->type); + } + DFPattern VisitDFPattern_(const RuntimeDepShapePatternNode* op) override { + return RuntimeDepShapePattern(op->pattern); + } + DFPattern VisitDFPattern_(const DataflowVarPatternNode* op) override { + return DataflowVarPattern(op->name); + } + DFPattern VisitDFPattern_(const ExternFuncPatternNode* op) override { + return ExternFuncPattern(op->global_symbol()); + } + DFPattern VisitDFPattern_(const PrimArrPatternNode* op) override { + return PrimArrPattern(op->fields); + } +}; + +// Syntatic Sugar +CallPattern DFPattern::operator()(const std::vector& args) const { + return CallPattern(*this, Array(args)); +} +OrPattern DFPattern::operator|(const DFPattern& other) const { return OrPattern(*this, other); } + +AndPattern DFPattern::operator&(const DFPattern& other) const { return AndPattern(*this, other); } + +NotPattern DFPattern::operator~() const { return NotPattern(*this); } + +AttrPattern DFPattern::HasAttr(const Map& attrs) const { + return AttrPattern(*this, DictAttrs(attrs)); +} +TypePattern DFPattern::HasType(const Type& type) const { return 
TypePattern(*this, type); } +DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { + return DataTypePattern(*this, dtype); +} +DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { + return HasDtype(DataType(runtime::String2DLDataType(dtype))); +} +ShapePattern DFPattern::HasShape(const Array& shape) const { + return ShapePattern(*this, shape); +} +RuntimeDepShapePattern DFPattern::HasRuntimeDepShape() const { + return RuntimeDepShapePattern(*this); +} + +DFPattern::operator PatternSeq() const { return PatternSeq{{*this}}; } + +std::stack& pattern_ctx_stack() { + thread_local std::stack graph_pattern_managers; + return graph_pattern_managers; +} + +PatternContext PatternContext::Current() { + ICHECK(!pattern_ctx_stack().empty()) << "No active PatternContext found."; + return pattern_ctx_stack().top(); +} + +PatternContext::PatternContext(bool incremental) { + auto n = make_object(); + if (incremental) { + ICHECK(!pattern_ctx_stack().empty()) + << "Incremental context needs to be built inside a existing context."; + n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use; + n->constraints = pattern_ctx_stack().top()->constraints; + } + + data_ = std::move(n); +} + +void PatternContext::EnterWithScope() { pattern_ctx_stack().push(*this); } + +void PatternContext::ExitWithScope() { + ICHECK(pattern_ctx_stack().top().same_as(*this)); + pattern_ctx_stack().pop(); +} + +static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, PairCons pcon) { + PatternContext::Current().add_constraint(lhs, rhs, pcon); +} + +TVM_REGISTER_NODE_TYPE(PatternSeqNode); +PatternSeq::PatternSeq(DFPattern init_pattern) { + ObjectPtr n = make_object(); + n->patterns = {init_pattern}; + n->pair_constraints = {}; + data_ = std::move(n); +} +PatternSeq::PatternSeq(tvm::Array patterns, bool only_used_by) { + ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; + const auto cons = PairCons(only_used_by ? PairCons::kOnlyUsedBy : PairCons::kUsedBy); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::vector(n->patterns.size() - 1, cons); + data_ = std::move(n); +} + +PatternSeq PatternSeq::UsedBy(PatternSeq other, int index) const { + return relax::UsedBy(*this, other, index); +} + +PatternSeq PatternSeq::OnlyUsedBy(PatternSeq other, int index) const { + return relax::OnlyUsedBy(*this, other, index); +} + +PatternSeq PatternSeq::dup() const { + PatternSeq ret; + + ObjectPtr n = make_object(); + n->patterns = Array{}; + n->patterns.reserve(get()->patterns.size()); + n->pair_constraints = this->get()->pair_constraints; + + for (size_t i = 0; i < get()->patterns.size(); ++i) { + n->patterns.push_back(get()->patterns[i].dup()); + if (i >= 1) + sync_graph_constraints(n->patterns[i - 1], n->patterns[i], n->pair_constraints[i - 1]); + } + + ret.data_ = std::move(n); + + return ret; +} +TVM_REGISTER_GLOBAL("relax.dpl.PatternSeq") + .set_body_typed([](Array patterns, bool only_used_by) { + return PatternSeq(std::move(patterns), only_used_by); + }); +RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { + p->stream << "["; + for (size_t i = 0; i < node->patterns.size(); ++i) { + if (i != 0) + p->stream << (PairCons::kOnlyUsedBy == node->pair_constraints[i].type ? 
" >> " : " ^ "); + p->stream << node->patterns[i]; + } + p->stream << "]"; +}); + +TVM_REGISTER_GLOBAL("relax.dpl.used_by") + .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { + return lhs.UsedBy(rhs, index); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.only_used_by") + .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { + return lhs.OnlyUsedBy(rhs, index); + }); + +PatternSeq UsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { + PatternSeq ret; + + const auto constraint = PairCons{PairCons::kOnlyUsedBy, index}; + + sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), + PairCons{PairCons::kUsedBy, index}); + + Array patterns; + patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); + patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); + patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); + + std::vector pair_constraints = lhs->pair_constraints; + pair_constraints.reserve(pair_constraints.size() + rhs->pair_constraints.size() + 1); + pair_constraints.push_back(constraint); + pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), + rhs->pair_constraints.end()); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::move(pair_constraints); + ret.data_ = std::move(n); + + return ret; +} +PatternSeq operator^(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.UsedBy(rhs); } + +PatternSeq OnlyUsedBy(const PatternSeq& lhs, const PatternSeq& rhs, int index) { + PatternSeq ret; + + const auto constraint = PairCons{PairCons::kOnlyUsedBy, index}; + + sync_graph_constraints(lhs->patterns.back(), rhs->patterns.front(), constraint); + + Array patterns; + patterns.reserve(lhs->patterns.size() + rhs->patterns.size()); + patterns.insert(patterns.end(), lhs->patterns.begin(), lhs->patterns.end()); + patterns.insert(patterns.end(), rhs->patterns.begin(), rhs->patterns.end()); + + std::vector pair_constraints = lhs->pair_constraints; + pair_constraints.reserve(pair_constraints.size() + rhs->pair_constraints.size() + 1); + pair_constraints.push_back(constraint); + pair_constraints.insert(pair_constraints.end(), rhs->pair_constraints.begin(), + rhs->pair_constraints.end()); + + ObjectPtr n = make_object(); + n->patterns = std::move(patterns); + n->pair_constraints = std::move(pair_constraints); + ret.data_ = std::move(n); + + return ret; +} +PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs) { return lhs.OnlyUsedBy(rhs); } + +VarPattern IsVar(const String& name) { return VarPattern(name); } +ConstantPattern IsConst() { return ConstantPattern(make_object()); } +WildcardPattern Wildcard() { return WildcardPattern(make_object()); } +ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } +ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } +CallPattern IsCallTIR(const String& name, Optional var_args, + Optional> oshape) { + DFPattern arg_pattern; + if (!var_args.defined()) { + arg_pattern = Wildcard(); + } else { + arg_pattern = var_args.value(); + } + + DFPattern shape_pattern; + if (!oshape.defined()) { + shape_pattern = Wildcard(); + } else { + shape_pattern = PrimArrPattern(oshape.value()); + } + + return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, shape_pattern); +} + +CallPattern IsCallTIR(const String& name, TuplePattern var_args, Array> oshapes) { + Array shape_patterns; + shape_patterns.reserve(oshapes.size()); + for (auto shape : oshapes) 
shape_patterns.push_back(PrimArrPattern(std::move(shape))); + + return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args, + IsTuple(std::move(shape_patterns))); +} + +DFPattern IsTuple(const Array& fields, bool unordered) { + if (unordered) + return UnorderedTuplePattern(fields); + else + return TuplePattern(fields); +} +TupleGetItemPattern IsTupleGetItem(const DFPattern tuple, int index) { + return TupleGetItemPattern(tuple, index); +} + +DFPattern DFPattern::dup() const { + auto pattern = DFPatternDuplicator().VisitDFPattern(*this); + return pattern; +} + +TVM_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { + return pattern.dup(); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { return seq.dup(); }); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { + return PatternContext(incre); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { + return PatternContext::Current(); +}); + +class PatternContext::Internal { + public: + static void EnterScope(PatternContext pass_ctx) { pass_ctx.EnterWithScope(); } + static void ExitScope(PatternContext pass_ctx) { pass_ctx.ExitWithScope(); } +}; + +TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed(PatternContext::Internal::EnterScope); + +TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed(PatternContext::Internal::ExitScope); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_pattern_functor.cc b/src/relax/ir/dataflow_pattern_functor.cc new file mode 100644 index 0000000000..d05ebee4ca --- /dev/null +++ b/src/relax/ir/dataflow_pattern_functor.cc @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. 
+ */ + +#include + +namespace tvm { +namespace relax { + +// DFPatternVisitor + +void DFPatternVisitor::VisitDFPattern(const DFPattern& pattern) { + if (this->visited_.count(pattern.get()) == 0) { + visited_.insert(pattern.get()); + DFPatternFunctor::VisitDFPattern(pattern); + } +} + +void DFPatternVisitor::VisitDFPattern_(const OrPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const AndPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const NotPatternNode* op) { VisitDFPattern(op->reject); } + +void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { + VisitDFPattern(op->op); + if (op->args.defined()) { + for (auto arg : op->args) { + VisitDFPattern(arg); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) { + VisitDFPattern(op->pattern); +} + +void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) { + if (op->params.defined()) { + for (auto param : op->params) { + VisitDFPattern(param); + } + } + VisitDFPattern(op->body); +} + +void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { + VisitDFPattern(op->tuple); +} + +void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const UnorderedTuplePatternNode* op) { + if (op->fields.defined()) { + for (auto field : op->fields) { + VisitDFPattern(field); + } + } +} + +void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } + +// leaf nodes. +void DFPatternVisitor::VisitDFPattern_(const PrimArrPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const RuntimeDepShapePatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const DataflowVarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const GlobalVarPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ExternFuncPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc new file mode 100644 index 0000000000..0616e30f07 --- /dev/null +++ b/src/relax/ir/emit_te.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/emit_te.cc
+ */
+#include "./emit_te.h"
+
+#include <string>
+
+namespace tvm {
+namespace relax {
+
+// RXPlaceholderOpNode
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<RXPlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const RXPlaceholderOpNode*>(node.get());
+      p->stream << "rxplaceholder(" << op->name << ", " << op << ")";
+    });
+
+TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode);
+
+te::Tensor TETensor(Expr value, std::string name) {
+  auto n = make_object<RXPlaceholderOpNode>();
+  n->name = name;
+  n->value = value;
+
+  // If the value is a constant, it might come in as an argument of EmitTE, and thus its shape
+  // and checked type might not be properly set. In this case we set the shape and dtype of the
+  // returned TE tensor from the constant itself.
+  if (const auto* constant = value.as<ConstantNode>()) {
+    n->dtype = DataType(constant->data->dtype);
+
+    int ndim = constant->data->ndim;
+    ShapeTuple shape_tuple = constant->data.Shape();
+    Array<PrimExpr> shape;
+    shape.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      shape.push_back(IntImm(DataType::Int(32), shape_tuple[i]));
+    }
+    n->shape = std::move(shape);
+    return te::PlaceholderOp(n).output(0);
+  }
+
+  Expr shape_expr = value->shape();
+  CHECK(shape_expr->IsInstance<ShapeExprNode>())
+      << "ValueError: Expression does not have a known symbolic shape; please consider using "
+         "match_shape "
+      << "to constrain the shape before passing it into te_tensor";
+  Array<PrimExpr> shape = Downcast<ShapeExpr>(shape_expr)->values;
+  n->shape = shape;
+  Type type = value->checked_type();
+  ICHECK(type->IsInstance<DynTensorTypeNode>())
+      << "ValueError: Expression should have an inferred DynTensorType: " << type->GetTypeKey();
+  DataType dtype = Downcast<DynTensorType>(type)->dtype;
+  n->dtype = dtype;
+  return te::PlaceholderOp(n).output(0);
+}
+
+TVM_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h
new file mode 100644
index 0000000000..917d33fffa
--- /dev/null
+++ b/src/relax/ir/emit_te.h
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/emit_te.h
+ * \brief Tensor expression extension in Relax.
+ */
+#ifndef TVM_RELAX_IR_EMIT_TE_H_
+#define TVM_RELAX_IR_EMIT_TE_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/te/operation.h>
+
+#include <string>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief A placeholder op that represents a relax expression.
+ */
+class RXPlaceholderOpNode : public te::PlaceholderOpNode {
+ public:
+  /*! \brief The relax expression.
*/ + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("tag", &tag); + v->Visit("attrs", &attrs); + v->Visit("value", &value); + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "RXPlaceholderOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode); +}; + +/*! + * \brief create a te tensor from relax expression. + * \param value The relax experession. + * \param name The name of the tensor. + */ +te::Tensor TETensor(Expr value, std::string name = "rxplaceholder"); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_IR_EMIT_TE_H_ diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc new file mode 100644 index 0000000000..22f796cfc0 --- /dev/null +++ b/src/relax/ir/expr.cc @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +namespace tvm { + +RelayExpr RelayExprNode::shape() const { + if (this->shape_.defined()) { + return Downcast(this->shape_); + } + static const Op& op = Op::Get("relax.shape_of"); + RelayExpr self = GetRef(this); + return relay::Call(op, {self}, {}, {}); +} + +TVM_REGISTER_GLOBAL("ir.RelayExprShape").set_body_method(&RelayExprNode::shape); + +namespace relax { +using tvm::ReprPrinter; +using tvm::runtime::Optional; + +TVM_REGISTER_NODE_TYPE(ShapeExprNode); + +ShapeExpr::ShapeExpr(Array values, Span span) { + ObjectPtr n = make_object(); + n->values = std::move(values); + n->span = span; + n->shape_ = NullOpt; + n->checked_type_ = ShapeType(); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { + return ShapeExpr(values, span); +}); + +TVM_REGISTER_NODE_TYPE(RuntimeDepShapeNode); + +RuntimeDepShape::RuntimeDepShape(Span span) { + ObjectPtr n = make_object(); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.RuntimeDepShape").set_body_typed([](Span span) { + return RuntimeDepShape(span); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + const ShapeExprNode* node = static_cast(ref.get()); + p->stream << "ShapeExpr("; + for (auto it = node->values.begin(); it != node->values.end(); it++) { + if (it != node->values.begin()) { + p->stream << ", "; + } + p->stream << *it; + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(VarNode); + +Var::Var(Id vid, Optional shape_annotation, Optional type_annotation, Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + n->shape_ = std::move(shape_annotation); + if (type_annotation) { + n->checked_type_ = std::move(type_annotation.value()); + } + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Var") + 
.set_body_typed([](String name_hint, Optional shape_annotation, + Optional type_annotation, Span span) { + return Var(name_hint, shape_annotation, type_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.VarFromId") + .set_body_typed([](Id vid, Optional shape_annotation, Optional type_annotation, + Span span) { return Var(vid, shape_annotation, type_annotation, span); }); + +TVM_REGISTER_NODE_TYPE(DataflowVarNode); + +DataflowVar::DataflowVar(Id vid, Optional shape_annotation, Optional type_annotation, + Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + n->shape_ = std::move(shape_annotation); + if (type_annotation) { + n->checked_type_ = std::move(type_annotation.value()); + } + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowVar") + .set_body_typed([](String name_hint, Optional shape_annotation, + Optional type_annotation, Span span) { + return DataflowVar(name_hint, shape_annotation, type_annotation, span); + }); + +TVM_REGISTER_GLOBAL("relax.DataflowVarFromId") + .set_body_typed([](Id vid, Optional shape_annotation, Optional type_annotation, + Span span) { + return DataflowVar(vid, shape_annotation, type_annotation, span); + }); + +TVM_REGISTER_NODE_TYPE(BindingNode); + +TVM_REGISTER_NODE_TYPE(MatchShapeNode); + +MatchShape::MatchShape(Expr value, Array pattern, Var var, Span span) { + ObjectPtr n = make_object(); + n->value = std::move(value); + n->pattern = std::move(pattern); + n->var = std::move(var); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.MatchShape") + .set_body_typed([](Expr value, Array pattern, Var var, Span span) { + return MatchShape(value, pattern, var, span); + }); + +TVM_REGISTER_NODE_TYPE(VarBindingNode); + +VarBinding::VarBinding(Var var, Expr value, Span span) { + ObjectPtr n = make_object(); + n->var = std::move(var); + n->value = std::move(value); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { + return VarBinding(var, value, span); +}); + +TVM_REGISTER_NODE_TYPE(BindingBlockNode); + +BindingBlock::BindingBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(DataflowBlockNode); + +DataflowBlock::DataflowBlock(Array bindings, Span span) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); +}); + +TVM_REGISTER_NODE_TYPE(SeqExprNode); + +SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { + ObjectPtr n = make_object(); + n->blocks = std::move(blocks); + n->body = std::move(body); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.SeqExpr") + .set_body_typed([](Array blocks, Expr body, Span span) { + return SeqExpr(blocks, body, span); + }); + +TVM_REGISTER_NODE_TYPE(FunctionNode); + +Function::Function(Array params, Expr body, Type ret_type, Expr ret_shape, DictAttrs attrs, + Span span) { + // Set the function type. + // For function, we take a conservative approach and require the function type + // to be known at construction time. 
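+  // Concretely: every parameter must already carry a checked_type_, and the
+  // return type is either taken from the ret_type annotation or deduced from
+  // body->checked_type_, preferring the deduced body type when both are present.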
+ Array param_types; + for (Var param : params) { + CHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_"; + param_types.push_back(param->checked_type_); + } + + if (!ret_type.defined()) { + CHECK(body->checked_type_.defined()) + << "relax.Function requires body to contain deduced checked_type_" + << " or ret_type to be supplied"; + ret_type = body->checked_type_; + } else { + if (body->checked_type_.defined()) { + CHECK(IsBaseOf(ret_type, body->checked_type_)) + << "relax.Function requires the deduced body->checked_type_ to be a subtype of the " + "annotated ret_type but meet body->checked_type_: " + << body->checked_type_ << ", ret_type: " << ret_type; + + // Use the more refined body->checked_type_ as the return type. + ret_type = body->checked_type_; + } + } + auto func_type = FuncType(param_types, ret_type, {}, {}); + + // set the fields + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + n->ret_type = std::move(ret_type); + n->ret_shape = std::move(ret_shape); + n->checked_type_ = std::move(func_type); + n->attrs = std::move(attrs); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Function") + .set_body_typed([](Array params, Expr body, Type ret_type, Expr ret_shape, DictAttrs attrs, + Span span) { + return Function(params, body, ret_type, ret_shape, attrs, span); + }); + +Function Function::CreateUnchecked(Array params, Expr body, Type ret_type, Expr ret_shape, + DictAttrs attrs, Span span) { + for (Var param : params) { + ICHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_."; + } + + // set the fields + ObjectPtr n = make_object(); + n->params = std::move(params); + n->body = std::move(body); + n->ret_type = std::move(ret_type); + n->ret_shape = std::move(ret_shape); + n->attrs = std::move(attrs); + n->span = std::move(span); + return Function(std::move(n)); +} + +TVM_REGISTER_GLOBAL("relax.Function_CreateUnchecked") + .set_body_typed([](Array params, Expr body, Type ret_type, Expr ret_shape, DictAttrs attrs, + Span span) { + return Function::CreateUnchecked(params, body, ret_type, ret_shape, attrs, span); + }); + +TVM_REGISTER_NODE_TYPE(ExternFuncNode); + +ExternFunc::ExternFunc(String global_symbol, Span span) { + ObjectPtr n = make_object(); + n->global_symbol = std::move(global_symbol); + n->span = span; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) { + return ExternFunc(global_symbol, span); +}); + +void UpdateType(Expr expr, Type type) { + ICHECK(!expr->checked_type_.defined() || tvm::StructuralEqual()(expr->checked_type_, type)) + << "the checked_type_ of the Expr must not be nullptr for idempotency"; + expr->checked_type_ = type; +} + +TVM_REGISTER_GLOBAL("relax.UpdateType").set_body_typed([](Expr expr, Type type) { + UpdateType(expr, type); +}); + +void UpdateShape(Expr expr, Optional shape) { + ICHECK(!expr->shape_.defined()) << "the shape_ of the Expr must not be nullptr for idempotency"; + expr->shape_ = shape; +} + +TVM_REGISTER_GLOBAL("relax.UpdateShape").set_body_typed([](Expr expr, Optional shape) { + UpdateShape(expr, shape); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc new file mode 100644 index 0000000000..cae5baf74c --- /dev/null +++ b/src/relax/ir/expr_functor.cc @@ -0,0 +1,794 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/expr_functor.cc + * \brief A wrapper around ExprFunctor which functionally updates the AST. + * + * ExprMutator uses memoization and self return in order to amortize + * the cost of using functional updates. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// ================== +// ExprVisitor + +void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } + +void ExprVisitor::VisitExpr_(const ConstantNode* op) { + this->VisitSpan(op->span); + + if (op->shape_) { + this->VisitExpr(Downcast(op->shape_.value())); + } +} + +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const TupleNode* op) { + this->VisitSpan(op->span); + for (Expr field : op->fields) { + this->VisitExpr(field); + } + + if (op->shape_) { + this->VisitExpr(Downcast(op->shape_.value())); + } +} + +// Visit the use-site of a defined Var +void ExprVisitor::VisitExpr_(const VarNode* op) { this->VisitSpan(op->span); } + +// Visit the use-site of a defined DataflowVar +void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const FunctionNode* op) { + this->VisitSpan(op->span); + for (Var param : op->params) { + this->VisitVarDef(param); + } + + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitExpr_(const CallNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->op); + + for (Type ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (Expr arg : op->args) { + this->VisitExpr(arg); + } + + if (op->shape_) { + this->VisitExpr(Downcast(op->shape_.value())); + } +} + +void ExprVisitor::VisitExpr_(const IfNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->cond); + this->VisitExpr(op->true_branch); + this->VisitExpr(op->false_branch); +} + +void ExprVisitor::VisitExpr_(const OpNode* op) {} + +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { + this->VisitSpan(op->span); + this->VisitExpr(op->tuple); +} + +void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const RuntimeDepShapeNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { this->VisitSpan(op->span); } + +void ExprVisitor::VisitExpr_(const SeqExprNode* op) { + this->VisitSpan(op->span); + for (BindingBlock block : op->blocks) { + this->VisitBindingBlock(block); + } + this->VisitExpr(op->body); +} + +void ExprVisitor::VisitType(const Type& t) {} + +void ExprVisitor::VisitSpan(const Span& span) {} + +void ExprVisitor::VisitBinding_(const VarBindingNode* binding) { + this->VisitExpr(binding->value); + this->VisitVarDef(binding->var); +} + +void 
ExprVisitor::VisitBinding_(const MatchShapeNode* binding) { + this->VisitExpr(binding->value); + // TODO(ziheng): should we change pattern from + // Array to ShapeExpr? + this->VisitExpr(ShapeExpr(binding->pattern)); + if (binding->var.defined()) { + this->VisitVarDef(binding->var); + } +} + +void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } +} + +void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { + this->VisitSpan(var->span); + + if (var->shape_) { + this->VisitExpr(Downcast(var->shape_.value())); + } +} + +void ExprVisitor::VisitVarDef_(const VarNode* var) { + this->VisitSpan(var->span); + + if (var->shape_) { + this->VisitExpr(Downcast(var->shape_.value())); + } +} + +void ExprVisitor::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { + if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } +} + +void ExprVisitor::VisitVarDef(const Var& var) { + if (const auto* node = var.as()) { + VisitVarDef_(node); + } else if (const auto* node = var.as()) { + VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } +} + +class ExprApplyVisit : public ExprVisitor { + public: + explicit ExprApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const Expr& e) final { + ExprVisitor::VisitExpr(e); + f_(e); + } + + private: + std::function f_; +}; + +void PostOrderVisit(const Expr& e, std::function fvisit) { + ExprApplyVisit(fvisit).VisitExpr(e); +} + +TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); + +// ================== +// ExprMutatorBase + +Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } + +Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { + bool unchanged = true; + tvm::Array fields; + for (Expr field : op->fields) { + Expr new_field = this->VisitExpr(field); + fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + if (unchanged) { + return GetRef(op); + } else { + Expr new_tuple = Tuple(fields, op->span); + return new_tuple; + } +} + +// Visit the use-site of a defined Var +Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { return GetRef(op); } + +// Visit the use-site of a defined DataflowVar +Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { + Expr body = this->VisitExpr(op->body); + Expr ret_shape = this->VisitExpr(op->ret_shape); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_type, op->ret_shape, op->attrs); + } +} + +Expr 
ExprMutatorBase::VisitExpr_(const CallNode* call_node) { + Expr new_op = this->VisitExpr(call_node->op); + bool unchanged = call_node->op.same_as(new_op); + + tvm::Array ty_args; + for (Type ty_arg : call_node->type_args) { + Type new_ty_arg = this->VisitType(ty_arg); + ty_args.push_back(new_ty_arg); + unchanged &= new_ty_arg.same_as(ty_arg); + } + + tvm::Array call_args; + for (Expr arg : call_node->args) { + Expr new_arg = this->VisitExpr(arg); + call_args.push_back(new_arg); + unchanged &= new_arg.same_as(arg); + } + + if (unchanged) { + return GetRef(call_node); + } else { + Expr new_call = Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); + return new_call; + } +} + +Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitExpr(op->true_branch); + Expr false_b = this->VisitExpr(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { + auto t = this->VisitExpr(op->tuple); + if (op->tuple.same_as(t)) { + return GetRef(op); + } else { + return TupleGetItem(t, op->index, op->span); + } +} + +Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const RuntimeDepShapeNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { return GetRef(op); } + +Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + Expr body = this->VisitExpr(op->body); + + if (all_blocks_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } +} + +BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { + Array bindings; + if (const auto* node = block.as()) { + for (auto binding : node->bindings) { + if (auto var_binding = binding.as()) { + Expr new_value = this->VisitExpr(var_binding->value); + bindings.push_back(VarBinding(var_binding->var, new_value)); + } else if (auto match_shape_binding = binding.as()) { + Expr new_value = this->VisitExpr(match_shape_binding->value); + bindings.push_back( + MatchShape(new_value, match_shape_binding->pattern, match_shape_binding->var)); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + + if (block.as()) { + return DataflowBlock(bindings); + } else { + return BindingBlock(bindings); + } +} + +Type ExprMutatorBase::VisitType(const Type& t) { return t; } + +// ================== +// ExprMutator + +Expr ExprMutator::VisitExpr(const Expr& expr) { + return builder_->Normalize(ExprFunctor::VisitExpr(expr)); +} + +Expr ExprMutator::VisitExpr_(const TupleNode* op) { + bool unchanged = true; + tvm::Array fields; + for (Expr field : op->fields) { + Expr new_field = this->VisitExpr(field); + fields.push_back(new_field); + unchanged &= new_field.same_as(field); + } + + if (unchanged) { + return GetRef(op); + } else { + Expr new_tuple = Tuple(fields, 
op->span); + return new_tuple; + } +} + +// Visit the use-site of a defined Var +Expr ExprMutator::VisitExpr_(const VarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. + return GetRef(op); +} + +// Visit the use-site of a defined DataflowVar +Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { + auto it = var_remap_.find(op->vid); + if (it != var_remap_.end()) { + return it->second; + } + + // default case return self. + return GetRef(op); +} + +Expr ExprMutator::VisitExpr_(const FunctionNode* op) { + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : op->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Type ret_type = this->VisitType(op->ret_type); + Expr ret_shape = this->VisitExpr(op->ret_shape); + Expr body = this->VisitWithNewScope(op->body); + + if (all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body) && + ret_shape.same_as(op->ret_shape)) { + return GetRef(op); + } else { + return Function(params, body, ret_type, ret_shape, op->attrs); + } +} + +Expr ExprMutator::VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } +} + +Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } +} + +void ExprMutator::VisitBinding_(const VarBindingNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + Var new_var = this->VisitVarDef(binding->var); + + auto emit = [this](VarBinding b) { + if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as()) { + this->builder_->EmitOutput(b); + } else { + this->builder_->Emit(b); + } + }; + + // FIXME(@altanh): try to clean up all the fast paths and ty/shape infer, it's getting unwieldy + // if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + // // no-op if there is no change + // emit(GetRef(binding)); + // return; + // } + + // fast path: reemit binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + emit(GetRef(binding)); + return; + } + + Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_); + if (!temp.same_as(new_var)) { + new_var = temp; + this->var_remap_[binding->var->vid] = new_var; + } + + emit(VarBinding(new_var, new_value)); +} + +void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + Expr new_pattern = this->VisitExpr(ShapeExpr(binding->pattern)); + + Var new_var; + if (binding->var.defined()) { 
+ // in the case of `x = R.match_shape(val, pattern)`, we want `x` to directly get `pattern` as + // the shape when `val` is a tensor. + Optional new_shape; + if (new_value->checked_type_.defined() && new_value->checked_type_.as()) { + new_shape = new_pattern; + } + new_var = this->VisitVarDef(binding->var); + Var temp = WithShapeAndType(new_var, new_shape, new_value->checked_type_); + if (!temp.same_as(new_var)) { + new_var = temp; + this->var_remap_[binding->var->vid] = new_var; + } + } + + // reemit old binding if nothing changes + if (new_value.same_as(binding->value) && new_pattern.same_as(binding->pattern)) { + if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) { + builder_->EmitMatchShape(GetRef(binding)); + return; + } + } + + // TODO(@altanh, @yuchen): shape and type inference here too... + // TODO(@yuchen): when value's shape/type changed, create new var + // TODO(@yuchen): group the can prove shape/type logic and replace var into a function + builder_->EmitMatchShape( + MatchShape(new_value, Downcast(new_pattern)->values, new_var)); +} + +BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (auto binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); +} + +Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { + bool shape_unchanged = true; + Expr new_shape; + if (var->shape_) { + new_shape = this->VisitExpr(Downcast(var->shape_.value())); + shape_unchanged &= new_shape.same_as(var->shape_); + } + + if (shape_unchanged) { + return GetRef(var); + } else { + Var new_var = DataflowVar(var->vid, NullOpt, var->checked_type_, var->span); + UpdateShape(new_var, new_shape); + + this->var_remap_[var->vid] = new_var; + return new_var; + } +} + +Var ExprMutator::VisitVarDef_(const VarNode* var) { + bool shape_unchanged = true; + Expr new_shape; + if (var->shape_) { + new_shape = this->VisitExpr(Downcast(var->shape_.value())); + shape_unchanged &= new_shape.same_as(var->shape_); + } + + if (shape_unchanged) { + return GetRef(var); + } else { + Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span); + UpdateShape(new_var, new_shape); + + this->var_remap_[var->vid] = new_var; + return new_var; + } +} + +void ExprMutator::VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } +} + +BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; +} + +Var ExprMutator::VisitVarDef(const Var& var) { + Var ret; + if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else if (const auto* node = var.as()) { + ret = VisitVarDef_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + } + return ret; +} + +Expr ExprMutator::VisitWithNewScope(const Expr& expr) { + builder_->BeginBindingBlock(); + Expr 
ret = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + ret = SeqExpr({prologue}, ret); + } + return ret; +} + +Optional ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } + +Var ExprMutator::WithShapeAndType(Var var, Optional shape, Type type) { + // shape/type changes if it goes from defined -> undefined or the other way, hence xor + bool shape_changed = var->shape_.operator bool() ^ shape.operator bool(); + shape_changed |= var->shape_ && shape && + !builder_->CanProveShapeEqual(Downcast(var->shape_.value()), + Downcast(shape.value())); + + bool type_changed = var->checked_type_.defined() ^ type.defined(); + type_changed |= var->checked_type_.defined() && type.defined() && + !StructuralEqual()(var->checked_type_, type); + + if (shape_changed || type_changed) { + Var new_var = var.as() ? DataflowVar(var->vid, NullOpt, NullOpt, var->span) + : Var(var->vid, NullOpt, NullOpt, var->span); + UpdateShape(new_var, var->shape_); + UpdateType(new_var, var->checked_type_); + var = new_var; + } + + if (shape_changed) { + var->shape_ = shape; + } + + if (type_changed) { + var->checked_type_ = type; + } + + return var; +} + +TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { + visitor->ExprVisitor::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->ExprVisitor::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->ExprVisitor::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { + visitor->ExprVisitor::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitType") + .set_body_typed([](PyExprVisitor visitor, const Type& type) { + visitor->ExprVisitor::VisitType(type); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") + .set_body_typed([](PyExprVisitor visitor, const Span& span) { + visitor->ExprVisitor::VisitSpan(span); + }); + +TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator visitor, const BindingBlock& 
block) { + return visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator visitor, const Var& var) { + return visitor->VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->ExprMutator::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator visitor, const Binding& binding) { + return visitor->ExprMutator::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator visitor, const BindingBlock& block) { + return visitor->ExprMutator::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator visitor, const Var& var) { + return visitor->ExprMutator::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitType") + .set_body_typed([](PyExprMutator visitor, const Type& type) { + return visitor->ExprMutator::VisitType(type); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->VisitExprPostOrder(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->VisitWithNewScope(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") + .set_body_typed([](PyExprMutator visitor, const Var& var) { + return visitor->LookupBinding(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithShapeAndType") + .set_body_typed([](PyExprMutator visitor, Var var, Optional shape, Type type) { + return visitor->WithShapeAndType(var, shape, type); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") + .set_body_typed([](PyExprMutator visitor, Id id, Var var) { + return visitor->var_remap_[id] = var; + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") + .set_body_typed([](PyExprMutator visitor, Id id) { return visitor->var_remap_[id]; }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc new file mode 100644 index 0000000000..5b2c01e54c --- /dev/null +++ b/src/relax/ir/transform.cc @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/ir/transform.cc + * \brief Relax specific transformation passes. 
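+ *
+ * This file implements the FunctionPass and DataflowBlockPass infrastructure
+ * for Relax on top of tvm::transform::Pass.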
+ */ +#include +#include +#include +#include +#include +#include +#include +namespace tvm { +namespace relax { +namespace transform { + +TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm); + +// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we upstream +class FunctionPass; + +/*! + * \brief Function-level passes are used to implement various global + * optimizations for a given Relax IRModule. It fetches one function at a time + * from the function list in the IRModule for optimization. + * + * Note that the scope of passes at this level is a Relax function. Therefore, + * we cannot add or delete a function through these passes as they are not aware + * of the global information. + */ +class FunctionPassNode : public tvm::transform::PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relax function as a + * `pass_func` and let it run on a given IRModule. The same `pass_func` will + * then be applied on each function in the IRModule. + */ + runtime::TypedPackedFunc pass_func; + + FunctionPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + /*! + * \brief Run a function pass on given pass context. + * + * \param mod The IRModule that an optimization pass is applied on. + * \param pass_ctx The context that an optimization pass executes on. + * + * \return Return the updated IRModule. + */ + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + /*! + * \brief Get the pass information/meta data. + */ + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.FunctionPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); + + private: + /* + * \brief Check if a function should be skipped for optimization. + * + * \param func The target function to be checked. + * + * \return Return true if the function will be skipped, otherwise false. + */ + bool SkipFunction(const Function& func) const; +}; + +class FunctionPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. + */ + TVM_DLL FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; + +FunctionPass::FunctionPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule optimizations at the Function level. 
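+//
+// Illustrative usage sketch (the pass name and identity lambda here are
+// hypothetical, not part of this change):
+//
+//   auto pass_func = [](Function f, IRModule mod, PassContext ctx) -> Function {
+//     return f;  // identity rewrite; a real pass would return a transformed `f`
+//   };
+//   Pass identity = CreateFunctionPass(pass_func, /*opt_level=*/0,
+//                                      "IdentityFunctionPass", /*required=*/{},
+//                                      /*traceable=*/false);
+//   IRModule updated = identity(mod);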
+IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+  if (pass_ctx->diag_ctx) {
+    DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+    pass_ctx->diag_ctx = previous;
+    previous = tmp;
+  } else {
+    pass_ctx->diag_ctx = previous;
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block, this is a bug.";
+
+  const PassInfo& pass_info = Info();
+
+  ICHECK(mod.defined());
+
+  VLOG_CONTEXT << pass_info->name;
+  VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level;
+  VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod);
+
+  IRModule updated_mod = mod->ShallowCopy();
+
+  std::vector<std::pair<GlobalVar, Function>> updates;
+  for (const auto& it : updated_mod->functions) {
+    // only picks up relax::Function
+    if (auto* n = it.second.as<FunctionNode>()) {
+      Function func = GetRef<Function>(n);
+      auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx);
+      updates.push_back({it.first, updated_func});
+    }
+  }
+
+  for (const auto& pair : updates) {
+    updated_mod->Add(pair.first, pair.second, true);
+  }
+
+  ICHECK(pass_ctx->diag_ctx)
+      << "The diagnostic context was set at the top of this block, this is a bug.";
+
+  pass_ctx->diag_ctx.value().Render();
+  pass_ctx->diag_ctx = previous;
+
+  VLOG(1) << "Output module:" << std::endl << PrettyPrint(updated_mod);
+
+  return updated_mod;
+}
+
+bool FunctionPassNode::SkipFunction(const Function& func) const {
+  // TODO(@yuchen): will need to revisit in the future
+  return (func->GetAttr<String>(relay::attr::kCompiler).defined()) ||
+         func->GetAttr<Integer>(relay::attr::kSkipOptimization, 0) != 0;
+}
+
+Pass CreateFunctionPass(
+    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
+    int opt_level, String name, tvm::Array<runtime::String> required, bool traceable) {
+  PassInfo pass_info = PassInfo(opt_level, name, required, traceable);
+  return FunctionPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(FunctionPassNode);
+
+TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass")
+    .set_body_typed(
+        [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
+           PassInfo pass_info) { return FunctionPass(pass_func, pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const FunctionPassNode*>(ref.get());
+      const PassInfo info = node->Info();
+      p->stream << "Run Function pass: " << info->name << " at the optimization level "
+                << info->opt_level;
+    });
+
+class DataflowBlockPass;
+
+/*!
+ * \brief DataflowBlock-level passes are used to implement various dataflow block
+ * optimizations for a given Relax IRModule. It fetches one dataflow block at a time
+ * from the functions in an IRModule, and yields a rewritten DataflowBlock.
+ *
+ * Note that the scope of passes at this level is a Relax DataflowBlock. Therefore,
+ * such passes cannot modify the global scope Vars or the symbolic shape Vars defined
+ * inside the dataflow block.
+ */
+class DataflowBlockPassNode : public tvm::transform::PassNode {
+ public:
+  /* \brief The pass meta data.*/
+  PassInfo pass_info;
+
+  /*! \brief The packed pass function sketches the real optimization. For
+   * instance, we can implement a pass that works on a Relax DataflowBlock as a
+   * `pass_func` and let it run on a given IRModule. The same `pass_func` will
+   * then be applied on each DataflowBlock in the IRModule.
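+   *
+   * A minimal sketch, assuming the CreateDataflowBlockPass helper defined later
+   * in this file (the pass name and identity lambda are hypothetical):
+   * \code
+   * Pass p = CreateDataflowBlockPass(
+   *     [](DataflowBlock b, IRModule m, PassContext ctx) { return b; },
+   *     0, "IdentityBlockPass", {}, false);
+   * \endcode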
+ */ + runtime::TypedPackedFunc pass_func; + + DataflowBlockPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } + + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; + + PassInfo Info() const override { return pass_info; } + + static constexpr const char* _type_key = "relax.DataflowBlockPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode); +}; + +/*! \brief Helper to apply the passed function to dataflow blocks.*/ +class DataflowBlockMutator : public ExprMutator { + public: + DataflowBlockMutator( + runtime::TypedPackedFunc pass_func, + IRModule mod, PassContext pass_ctx) + : pass_func_(pass_func), mod_(mod), pass_ctx_(pass_ctx) {} + + /*! + * \brief Rewrite the DataflowBlockNode with pass_func_ + * + * This function will check that there are no rewrites of the global scope Vars + * and symbolic shape Vars defined inside the dataflow block. + */ + BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final { + // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock + Map global_scope_vars; + Map symbolic_vars; + for (const Binding& binding : n->bindings) { + Var var; + if (const auto* node = binding.as()) { + var = node->var; + } else if (const auto* node = binding.as()) { + var = node->var; + for (PrimExpr expr : node->pattern) { + if (const tir::VarNode* sym_var = expr.as()) { + symbolic_vars.Set(sym_var->name_hint, Downcast(expr)); + } + } + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + if (!var.as()) { + global_scope_vars.Set(var->name_hint(), var); + } + } + + // apply pass_func_ to the DataflowBlock + DataflowBlock block = GetRef(n); + DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_); + + // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars + for (const Binding& binding : updated_block->bindings) { + Var var; + if (const auto* node = binding.as()) { + var = node->var; + } else if (const auto* node = binding.as()) { + var = node->var; + for (PrimExpr expr : node->pattern) { + if (const tir::VarNode* sym_var = expr.as()) { + if (symbolic_vars.count(sym_var->name_hint) > 0) { + tir::Var old_var = symbolic_vars[sym_var->name_hint]; + ICHECK(expr.same_as(old_var)) + << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; + symbolic_vars.erase(sym_var->name_hint); + } + } + } + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + if (!var.as() && global_scope_vars.count(var->name_hint()) > 0) { + ICHECK(var.same_as(global_scope_vars[var->name_hint()])) + << "Error: DataflowBlock Pass should not rewrite any GlobalScope Var."; + global_scope_vars.erase(var->name_hint()); + } + } + ICHECK(global_scope_vars.empty() && symbolic_vars.empty()) + << "Error: DataflowBlock Pass should not delete any GlobalScope/Symbolic Var."; + + return std::move(updated_block); + } + + private: + runtime::TypedPackedFunc pass_func_; + IRModule mod_; + PassContext pass_ctx_; +}; + +class DataflowBlockPass : public Pass { + public: + /*! + * \brief The constructor + * \param pass_func The packed function which implements a pass. + * \param pass_info The pass info. 
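+   * \sa CreateDataflowBlockPass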
+ */ + TVM_DLL DataflowBlockPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass, DataflowBlockPassNode); +}; + +DataflowBlockPass::DataflowBlockPass( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + data_ = std::move(n); +} + +// Perform IRModule -> IRModule transformations at the DataflowBlock level. +IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + DiagnosticContext previous = DiagnosticContext::Default(mod); + + if (pass_ctx->diag_ctx) { + DiagnosticContext tmp = pass_ctx->diag_ctx.value(); + pass_ctx->diag_ctx = previous; + previous = tmp; + } else { + pass_ctx->diag_ctx = previous; + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block, this is a bug."; + + const PassInfo& pass_info = Info(); + + ICHECK(mod.defined()); + + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing DataflowBlock pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); + + IRModule updated_mod = mod->ShallowCopy(); + + DataflowBlockMutator dataflow_block_mutator(pass_func, updated_mod, pass_ctx); + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relax::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + Function updated_func = Downcast(dataflow_block_mutator.VisitExpr(func)); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + + ICHECK(pass_ctx->diag_ctx) + << "The diagnostic context was set at the top of this block this is a bug."; + + pass_ctx->diag_ctx.value().Render(); + pass_ctx->diag_ctx = previous; + + VLOG(1) << "Output module:" << std::endl << PrettyPrint(updated_mod); + + return updated_mod; +} + +Pass CreateDataflowBlockPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); + return DataflowBlockPass(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); + +TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return DataflowBlockPass(pass_func, pass_info); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run DataflowBlock pass: " << info->name << " at the optimization level " + << info->opt_level; + }); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc new file mode 100644 index 0000000000..d7304e7ae6 --- /dev/null +++ b/src/relax/ir/type.cc @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/ir/type.cc
+ * \brief Relax's type system AST nodes throughout the IR.
+ */
+#include <tvm/ir/type_functor.h>
+#include <tvm/relax/type.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_NODE_TYPE(ShapeTypeNode);
+
+ShapeType::ShapeType(Span span) {
+  ObjectPtr<ShapeTypeNode> n = make_object<ShapeTypeNode>();
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](Span span) { return ShapeType(span); });
+
+ObjectType::ObjectType(Span span) {
+  ObjectPtr<ObjectTypeNode> n = make_object<ObjectTypeNode>();
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ObjectTypeNode);
+
+TVM_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { return ObjectType(span); });
+
+DynTensorType::DynTensorType(int ndim, DataType dtype, Span span) {
+  ObjectPtr<DynTensorTypeNode> n = make_object<DynTensorTypeNode>();
+  n->ndim = std::move(ndim);
+  n->dtype = std::move(dtype);
+  n->span = span;
+  data_ = std::move(n);
+}
+
+DynTensorType DynTensorType::CreateUnknownNDim(DataType dtype, Span span) {
+  ObjectPtr<DynTensorTypeNode> n = make_object<DynTensorTypeNode>();
+  n->ndim = -1;
+  n->dtype = std::move(dtype);
+  n->span = std::move(span);
+  return DynTensorType(std::move(n));
+}
+
+TVM_REGISTER_NODE_TYPE(DynTensorTypeNode);
+
+TVM_REGISTER_GLOBAL("relax.DynTensorType").set_body_typed([](int ndim, DataType dtype, Span span) {
+  return DynTensorType(ndim, dtype, span);
+});
+
+DimType::DimType(Span span) {
+  ObjectPtr<DimTypeNode> n = make_object<DimTypeNode>();
+  n->span = span;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DimTypeNode);
+
+TVM_REGISTER_GLOBAL("relax.DimType").set_body_typed([](Span span) { return DimType(span); });
+
+/*!
+ * \brief Utility class for generic type dispatching:
+ * VisitType dispatches on the base type and checks if the derived type is a subtype of the base
+ * type.
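+ *
+ * For example, ObjectType is a base of every Relax type, and a DynTensorType
+ * with unknown ndim or dtype is a base of any DynTensorType that refines it
+ * (e.g. DynTensorType(2, DataType::Float(32))).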
+ */ +class BaseTypeChecker : public TypeFunctor { + public: + explicit BaseTypeChecker(const Type& derived) : derived_{derived} {} + + bool VisitType_(const ShapeTypeNode* base) final { + if (derived_.as()) { + return true; + } + return false; + } + bool VisitType_(const ObjectTypeNode* base) final { return true; } + + bool VisitType_(const DynTensorTypeNode* base) final { + if (auto derived_tensor = derived_.as()) { + if (base->IsUnknownNdim() || base->ndim == derived_tensor->ndim) { + if (base->IsUnknownDtype() || base->dtype == derived_tensor->dtype) { + return true; + } + } + } + return false; + } + + bool VisitType_(const TupleTypeNode* base) final { + if (auto derived_tuple = derived_.as()) { + if (base->fields.size() != derived_tuple->fields.size()) { + return false; + } + + for (size_t i = 0; i < base->fields.size(); ++i) { + if (!IsBaseOf(base->fields[i], derived_tuple->fields[i])) { + return false; + } + } + return true; + } + return false; + } + + bool VisitType_(const FuncTypeNode* base) final { + if (auto derived_func = derived_.as()) { + if (base->arg_types.size() != derived_func->arg_types.size()) { + return false; + } + for (size_t i = 0; i < base->arg_types.size(); ++i) { + if (!IsBaseOf(base->arg_types[i], derived_func->arg_types[i])) { + return false; + } + } + if (!IsBaseOf(base->ret_type, derived_func->ret_type)) { + return false; + } + return true; + } + return false; + } + + private: + Type derived_; +}; + +bool IsBaseOf(const Type& base, const Type& derived) { + BaseTypeChecker visitor(derived); + return visitor.VisitType(base); +} + +TVM_REGISTER_GLOBAL("relax.IsBaseOf").set_body_typed([](const Type& base, const Type& derived) { + return IsBaseOf(base, derived); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc new file mode 100644 index 0000000000..163f9ddbf5 --- /dev/null +++ b/src/relax/op/op.cc @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include +#include +#include +#include +#include + +#include "op_common.h" + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(AllocTensorAttrs); +TVM_REGISTER_NODE_TYPE(VMAllocStorageAttrs); +TVM_REGISTER_NODE_TYPE(VMAllocTensorAttrs); +TVM_REGISTER_NODE_TYPE(ShapeHeapAttrs); + +bool EqualConstInt(const PrimExpr& lhs, int64_t value) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + PrimExpr diff = lhs - rhs; + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + +Type ReturnVoidType(const Call& call, DiagnosticContext diag_ctx) { return VoidType(); } + +Type ReturnObjectType(const Call& call, DiagnosticContext diag_ctx) { return ObjectType(); } + +Type ReturnShapeType(const Call& call, DiagnosticContext diag_ctx) { return ShapeType(); } + +// call_tir + +Optional InferShapeCallTIR(const Call& call, DiagnosticContext diag_ctx) { + Expr output_shape = call->args[2]; + return output_shape; +} + +Type InferTypeArg(const Call& call, DiagnosticContext diag_ctx) { + if (call->type_args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "type_args should have exact 1 output type."); + } + Type output_type = call->type_args[0]; + return output_type; +} + +RELAY_REGISTER_OP("relax.call_tir") + .set_num_inputs(4) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", "The input arguments.") + .add_argument("output_shape", "Expr", "The output shape.") + .add_argument("packed_ints", "Expr", + "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " + "args if unused") + .set_attr("FInferShape", InferShapeCallTIR) + .set_attr("FInferType", InferTypeArg); + +Expr MakeCallTIR(Expr func, Tuple args, Expr output_shape, Type output_type, + Optional packed_ints) { + static const Op& op = Op::Get("relax.call_tir"); + Call call; + if (!packed_ints) { + // don't use additional optional argument + call = Call(op, {func, args, output_shape}, {}, {output_type}); + } else { + call = Call(op, {func, args, output_shape, packed_ints.value()}, {}, {output_type}); + } + return call; +} + +TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); + +// print +TVM_REGISTER_NODE_TYPE(PrintAttrs); + +RELAY_REGISTER_OP("relax.print") + .set_attrs_type() + .set_num_inputs(-1) + .add_argument("vals", "Array", "Values to print.") + .set_attr("FInferType", ReturnVoidType) + .set_attr("FCallPacked", "relax.run.print"); + +Expr MakePrint(Array vals, std::string format) { + auto attrs = make_object(); + attrs->format = format; + static const Op& op = Op::Get("relax.print"); + return Call(op, vals, Attrs(attrs)); +} + +TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); + +// assert_op + +// can't actually name it assert or else Python will consider it a syntax error + +Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) { + // Ensure that the condition argument is a boolean scalar. + // Also permitted is a tensor with unknown shape and unknown dtype + // (checked dynamically in that case). Returns void. 
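+  // The dynamic fallback goes through the "relax.run.assert_op" packed function
+  // registered below via the FCallPacked attribute.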
+ if (call->args.size() < 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Assert must have at least one argument (the condition)."); + } + Type arg_type = call->args[0]->checked_type(); + if (!IsBoolScalarType(arg_type)) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The argument to assert must be a boolean scalar type, but received " + << arg_type); + } + return VoidType(); +} + +TVM_REGISTER_NODE_TYPE(AssertOpAttrs); + +RELAY_REGISTER_OP("relax.assert_op") + .set_attrs_type() + .set_num_inputs(-1) + .add_argument("vals", "Array", + "The first value is used as the assertion condition. The others are used as " + "format arguments if there is an error.") + .set_attr("FInferType", InferAssertType) + .set_attr("FCallPacked", "relax.run.assert_op"); + +Expr MakeAssertOp(Expr condition, Array vals, std::string format) { + auto attrs = make_object(); + attrs->format = format; + static const Op& op = Op::Get("relax.assert_op"); + Array args = {condition}; + for (auto val : vals) { + args.push_back(val); + } + return Call(op, args, Attrs(attrs)); +} + +TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); + +// make_closure + +RELAY_REGISTER_OP("relax.make_closure") + .set_num_inputs(2) + .add_argument("func", "Expr", "The closure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferType", ReturnObjectType); + +Expr MakeClosure(Expr func, Tuple args) { + static const Op& op = Op::Get("relax.make_closure"); + return Call(op, {func, args}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); + +// invoke_closure + +RELAY_REGISTER_OP("relax.invoke_closure") + .set_num_inputs(2) + .add_argument("closure", "Expr", "The VMClosure.") + .add_argument("args", "Tuple", "The captured variables.") + .set_attr("FInferType", InferTypeArg); + +Expr InvokeClosure(Expr closure, Tuple args) { + static const Op& op = Op::Get("relax.invoke_closure"); + return Call(op, {closure, args}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); + +// shape_of + +RELAY_REGISTER_OP("relax.shape_of") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr("FInferType", ReturnShapeType); + +Expr MakeShapeOf(Expr expr) { + static const Op& op = Op::Get("relax.shape_of"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); + +// alloc_tensor + +Optional InferShapeAllocTensor(const Call& call, DiagnosticContext diag_ctx) { + return call->args[0]; +} + +Type InferTypeAllocTensor(const Call& call, DiagnosticContext diag_ctx) { + auto attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "must be AllocTensorAttrs, but got " << call->attrs->GetTypeKey(); + auto output_shape = call->args[0].as(); + ICHECK(output_shape != nullptr) << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); + return DynTensorType(output_shape->values.size(), attrs->dtype); +} + +RELAY_REGISTER_OP("relax.builtin.alloc_tensor") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .set_attr("FInferShape", InferShapeAllocTensor) + .set_attr("FInferType", InferTypeAllocTensor); + +Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + attrs->runtime_device_index = std::move(runtime_device_index); + static const Op& op = Op::Get("relax.builtin.alloc_tensor"); + 
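+  // `shape` is the only positional argument; the dtype and the runtime device
+  // index are carried on AllocTensorAttrs.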
return Call(op, {shape}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); + +// vm alloc_storage + +RELAY_REGISTER_OP("relax.vm.builtin.alloc_storage") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("size", "Expr", "The size of the storage to allocate.") + .set_attr("FInferType", ReturnObjectType); + +Expr MakeVMAllocStorage(Expr size, DataType dtype, int64_t runtime_device_index) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + attrs->runtime_device_index = std::move(runtime_device_index); + static const Op& op = Op::Get("relax.vm.builtin.alloc_storage"); + return Call(op, {size}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_storage").set_body_typed(MakeVMAllocStorage); + +// vm alloc_tensor + +Optional InferShapeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { + return call->args[1]; +} + +Type InferTypeVMAllocTensor(const Call& call, DiagnosticContext diag_ctx) { + auto attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "must be VMAllocTensorAttrs , but got " << call->attrs->GetTypeKey(); + if (const auto* output_shape = call->args[1].as()) { + return DynTensorType(output_shape->values.size(), attrs->dtype); + } + return DynTensorType::CreateUnknownNDim(attrs->dtype, Span()); +} + +RELAY_REGISTER_OP("relax.vm.builtin.alloc_tensor") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .set_attr("FInferShape", InferShapeVMAllocTensor) + .set_attr("FInferType", InferTypeVMAllocTensor); + +Expr MakeVMAllocTensor(Expr storage, Expr shape, DataType dtype, int64_t runtime_device_index) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + attrs->runtime_device_index = std::move(runtime_device_index); + static const Op& op = Op::Get("relax.vm.builtin.alloc_tensor"); + return Call(op, {storage, shape}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_tensor").set_body_typed(MakeVMAllocTensor); + +// vm store_shape + +RELAY_REGISTER_OP("relax.vm.builtin.store_shape") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("shape", "Expr", "The shape to be stored.") + .add_argument("heap", "Expr", "The heap to store the shape.") + .set_attr("FInferType", ReturnVoidType); + +Expr MakeStoreShape(Expr shape, Expr heap, Array indices) { + auto attrs = make_object(); + attrs->indices = std::move(indices); + static const Op& op = Op::Get("relax.vm.builtin.store_shape"); + return Call(op, {shape, heap}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.store_shape").set_body_typed(MakeStoreShape); + +// vm load_shape + +RELAY_REGISTER_OP("relax.vm.builtin.load_shape") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("heap", "Expr", "The heap to load the shape from.") + .set_attr("FInferType", ReturnShapeType); + +Expr MakeLoadShape(Expr heap, Array indices) { + auto attrs = make_object(); + attrs->indices = std::move(indices); + static const Op& op = Op::Get("relax.vm.builtin.load_shape"); + return Call(op, {heap}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.load_shape").set_body_typed(MakeLoadShape); + +// vm call_tir_dyn + +RELAY_REGISTER_OP("relax.vm.call_tir_dyn") + .set_num_inputs(2) + .add_argument("func", "Expr", "The destination-passing-style function.") + .add_argument("args", "Tuple", + "The input arguments (list of tensors and last 
argument is ShapeExpr)") + .set_attr("FInferType", ReturnVoidType); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h new file mode 100644 index 0000000000..7f97716ec0 --- /dev/null +++ b/src/relax/op/op_common.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file op_common.h + * \brief A set of utilities and common functionality + * for Relax ops. + */ +#ifndef TVM_RELAX_OP_OP_COMMON_H_ +#define TVM_RELAX_OP_OP_COMMON_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +bool EqualConstInt(const PrimExpr& lhs, int64_t value); + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs); + +/*! Quick helper macro + * - Expose a positional make function to construct the node. + * - Register op to the registry. + * + * We make the decision to always only expose positional argument. + * We will do rewrapping in the frontend to support language + * sugars such as keyword arguments and default value. + * + * \param OpName the name of registry. + */ +#define RELAX_REGISTER_BINARY_BROADCAST_OP(OpName) \ + TVM_REGISTER_GLOBAL("relax.op." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get("relax." OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP("relax." OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .set_attr("FInferShape", InferShapeBinaryBroadcast) \ + .set_attr("FInferType", InferTypeBinaryBroadcast) + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_OP_COMMON_H_ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc new file mode 100644 index 0000000000..7d60167878 --- /dev/null +++ b/src/relax/op/tensor/binary.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file binary.cc + * \brief binary broadcast operators. + */ + +#include "binary.h" + +namespace tvm { +namespace relax { + +RELAX_REGISTER_BINARY_BROADCAST_OP("add") + .describe("Elementwise add with broadcasting") + .set_support_level(1); + +RELAX_REGISTER_BINARY_BROADCAST_OP("multiply") + .describe("Elementwise multiply with broadcasting") + .set_support_level(1); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h new file mode 100644 index 0000000000..241bd856e8 --- /dev/null +++ b/src/relax/op/tensor/binary.h @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.h + * \brief shape and type deduction for binary broadcast operators. + */ + +#ifndef TVM_RELAX_OP_TENSOR_BINARY_H_ +#define TVM_RELAX_OP_TENSOR_BINARY_H_ + +#include +#include + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +Optional InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Binary broadcast op should have 2 arguments"); + } + Expr lhs_shape = call->args[0]->shape(); + Expr rhs_shape = call->args[1]->shape(); + auto* s0 = lhs_shape.as(); + auto* s1 = rhs_shape.as(); + if (s0 && s1) { + std::vector output_shape; + size_t ndim0 = s0->values.size(); + size_t ndim1 = s1->values.size(); + size_t i = 1; + for (; i <= std::min(ndim0, ndim1); ++i) { + PrimExpr dim0 = s0->values[ndim0 - i]; + PrimExpr dim1 = s1->values[ndim1 - i]; + if (EqualConstInt(dim0, 1)) { + output_shape.push_back(dim1); + } else if (EqualConstInt(dim1, 1)) { + output_shape.push_back(dim0); + } else if (EqualCheck(dim0, dim1)) { + output_shape.push_back(dim0); + } else { + // defer the computation of output shapes to runtime + // e.g., broadcast Tensor([m, n]), Tensor([k]) -> defer to runtime + return Call(ExternFunc(String("vm.binary_broadcast_shape_infer")), + {call->args[0], call->args[1]}, {}, {}); + } + } + size_t max_ndim = std::max(ndim0, ndim1); + auto& longer_shape = (ndim0 > ndim1) ? 
s0 : s1; + for (; i <= max_ndim; ++i) { + output_shape.push_back(longer_shape->values[max_ndim - i]); + } + return ShapeExpr(Array(output_shape.rbegin(), output_shape.rend())); + } else { + return NullOpt; + } +} + +Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Binary broadcast op should have 2 arguments"); + } + Type lhs_type = call->args[0]->checked_type(); + Type rhs_type = call->args[1]->checked_type(); + auto* t0 = lhs_type.as(); + auto* t1 = rhs_type.as(); + if (!t0 || !t1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Both lhs and rhs should be DynTensor for broadcasting, but got " + << lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey()); + } + + DataType output_dtype; + if (t0->IsUnknownDtype() || t1->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t0->dtype != t1->dtype) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Data types " << t0->dtype << " and " << t1->dtype + << " must be equal for broadcasting operators"); + } else { + output_dtype = t0->dtype; + } + + int output_ndim; + if (t0->IsUnknownNdim() || t1->IsUnknownNdim()) { + output_ndim = -1; + } else { + output_ndim = std::max(t0->ndim, t1->ndim); + } + return DynTensorType(output_ndim, output_dtype); +} + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_BINARY_H_ diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc new file mode 100644 index 0000000000..b629a7daec --- /dev/null +++ b/src/relax/op/tensor/ternary.cc @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.cc + * \brief ternary operators. + */ + +#include "ternary.h" + +namespace tvm { +namespace relax { + +RELAY_REGISTER_OP("relax.ewise_fma") + .set_num_inputs(3) + .add_argument("e1", "Expr", "The input expression") + .add_argument("e2", "Expr", "The input expression") + .add_argument("e3", "Expr", "The input expression") + .set_attr("FInferShape", InferShapeEwiseFMA) + .set_attr("FInferType", InferTypeEwiseFMA); + +Expr MakeEwiseFma(Expr expr1, Expr expr2, Expr expr3) { + static const Op& op = Op::Get("relax.ewise_fma"); + return Call(op, {expr1, expr2, expr3}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(MakeEwiseFma); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/ternary.h b/src/relax/op/tensor/ternary.h new file mode 100644 index 0000000000..4a843ccd9b --- /dev/null +++ b/src/relax/op/tensor/ternary.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ternary.h
+ * \brief shape and type deduction for ternary operators.
+ */
+
+#ifndef TVM_RELAX_OP_TENSOR_TERNARY_H_
+#define TVM_RELAX_OP_TENSOR_TERNARY_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/type.h>
+
+#include <vector>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+Optional<Expr> InferShapeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 3) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments");
+  }
+  Expr shape0 = call->args[0]->shape();
+  Expr shape1 = call->args[1]->shape();
+  Expr shape2 = call->args[2]->shape();
+  auto* s0 = shape0.as<ShapeExprNode>();
+  auto* s1 = shape1.as<ShapeExprNode>();
+  auto* s2 = shape2.as<ShapeExprNode>();
+  if (s0 && s1 && s2) {
+    std::vector<PrimExpr> output_shape;
+    size_t ndim0 = s0->values.size();
+    size_t ndim1 = s1->values.size();
+    size_t ndim2 = s2->values.size();
+    if (ndim0 != ndim1 || ndim1 != ndim2) {
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "The 3 arguments of EwiseFMA must have the same number of dimensions");
+    }
+    for (size_t i = 0; i < ndim0; ++i) {
+      PrimExpr dim0 = s0->values[i];
+      PrimExpr dim1 = s1->values[i];
+      PrimExpr dim2 = s2->values[i];
+      if (EqualCheck(dim0, dim1) && EqualCheck(dim1, dim2)) {
+        output_shape.push_back(dim0);
+      } else {
+        diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                           << "The 3 arguments of EwiseFMA must have the same shape");
+      }
+    }
+    return ShapeExpr(Array<PrimExpr>(output_shape.begin(), output_shape.end()));
+  } else {
+    return NullOpt;
+  }
+}
+
+Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 3) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments");
+  }
+  Type type0 = call->args[0]->checked_type();
+  Type type1 = call->args[1]->checked_type();
+  Type type2 = call->args[2]->checked_type();
+  auto* t0 = type0.as<DynTensorTypeNode>();
+  auto* t1 = type1.as<DynTensorTypeNode>();
+  auto* t2 = type2.as<DynTensorTypeNode>();
+  if (!t0 || !t1 || !t2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The 3 arguments of EwiseFMA should be DynTensor");
+  }
+
+  DataType output_dtype;
+  if (t0->IsUnknownDtype() || t1->IsUnknownDtype() || t2->IsUnknownDtype()) {
+    output_dtype = DataType::Void();
+  } else if (t0->dtype != t1->dtype || t1->dtype != t2->dtype) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Data types " << t0->dtype << ", " << t1->dtype << ", and " << t2->dtype
+                       << " must be equal for EwiseFMA");
+  } else {
+    output_dtype = t0->dtype;
+  }
+
+  int output_ndim;
+  if (t0->IsUnknownNdim() || t1->IsUnknownNdim() || t2->IsUnknownNdim()) {
+    output_ndim = -1;
+  } else {
+    output_ndim = t0->ndim;
+  }
+  return DynTensorType(output_ndim, output_dtype);
+}
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_TERNARY_H_
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
new file mode 100644
index 0000000000..228de3ae8c
--- /dev/null
+++ b/src/relax/op/tensor/unary.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file unary.cc
+ * \brief unary operators.
+ */
+
+#include "unary.h"
+
+namespace tvm {
+namespace relax {
+
+TVM_REGISTER_NODE_TYPE(UniqueAttrs);
+
+RELAY_REGISTER_OP("relax.unique")
+    .describe(
+        "This operation returns the unique elements and the new index of each item in a given "
+        "tensor.")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<UniqueAttrs>()
+    .set_attr<FInferShape>("FInferShape", InferShapeUnique)
+    .set_attr<FInferType>("FInferType", InferTypeUnique)
+    .set_attr<FCallPacked>("FCallPacked", "relax.run.unique");
+
+Expr MakeUnique(Expr data, bool sorted, bool return_inverse, bool return_counts, int dim) {
+  auto attrs = make_object<UniqueAttrs>();
+  attrs->sorted = sorted;
+  attrs->return_inverse = return_inverse;
+  attrs->return_counts = return_counts;
+  attrs->dim = dim;
+  static const Op& op = Op::Get("relax.unique");
+  return Call(op, {data}, Attrs(attrs));
+}
+
+TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(MakeUnique);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
new file mode 100644
index 0000000000..d033e838e1
--- /dev/null
+++ b/src/relax/op/tensor/unary.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file unary.h
+ * \brief shape and type deduction for unary operators.
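+ *
+ * Worked example (informal): for an input of type
+ * DynTensorType(ndim=2, dtype="float32"), relax.unique yields a 1-d tensor
+ * whose length is only known at runtime, so the shape is deduced as
+ * RuntimeDepShape() and the type as DynTensorType(ndim=1, dtype="float32").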
+ */
+
+#ifndef TVM_RELAX_OP_TENSOR_UNARY_H_
+#define TVM_RELAX_OP_TENSOR_UNARY_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/type.h>
+
+#include
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+Optional<Expr> InferShapeUnique(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument");
+  }
+  auto unique_attrs = call->attrs.as<UniqueAttrs>();
+  // Only default values of these attributes are supported right now.
+  if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1)
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "support for return_inverse, return_counts, and dim is not implemented");
+  return relax::RuntimeDepShape(call->span);
+}
+
+Type InferTypeUnique(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Unique op should have 1 argument");
+  }
+  auto* input_ty = call->args[0]->checked_type().as<DynTensorTypeNode>();
+  if (!input_ty) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Input should be DynTensor, but got "
+                       << call->args[0]->checked_type()->GetTypeKey());
+  }
+
+  // TODO(prakalp): Add support for return_inverse, return_counts and dim attributes. Only defaults
+  // are supported right now.
+  auto unique_attrs = call->attrs.as<UniqueAttrs>();
+  if (unique_attrs->return_counts || unique_attrs->return_inverse || unique_attrs->dim != -1)
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "support for return_inverse, return_counts, and dim is not implemented");
+  return DynTensorType(/*ndim=*/1, input_ty->dtype);
+}
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_UNARY_H_
diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc
new file mode 100644
index 0000000000..b1c1ed29af
--- /dev/null
+++ b/src/relax/transform/annotate_tir_op_pattern.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/annotate_tir_op_pattern.cc
+ * \brief Annotate Op Pattern for TIR functions. It is a pass that works on TIR PrimFuncs,
+ * but the annotations are needed for relax fusion, so we put the pass in the relax namespace.
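+ *
+ * For example, an elementwise PrimFunc produced for relax.add would get an
+ * integer "op_pattern" attribute corresponding to kElemWise, while a function
+ * the analyzer cannot classify stays kOpaque; FuseOps later reads this
+ * attribute to decide which call_tir bindings may be grouped.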
+ */
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
+  if (f->HasNonzeroAttr("op_pattern")) {
+    return f;
+  } else {
+    relay::OpPatternKind kind = AnalyzeOpPatternKind(f);
+    return WithAttr(std::move(f), "op_pattern", Integer(static_cast<int>(kind)));
+  }
+}
+
+namespace transform {
+
+Pass AnnotateTIROpPattern() {
+  auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
+    return AnnotateOpPattern(std::move(f));
+  };
+  return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc
new file mode 100644
index 0000000000..550ba43237
--- /dev/null
+++ b/src/relax/transform/bind_params.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Bind params to function by using name
+ * \param func Relax function
+ * \param params params dict
+ * \return Function
+ */
+inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
+  std::unordered_map<std::string, relax::Var> name_dict;
+  std::unordered_set<relax::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto& name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(name_dict[name]);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+
+  std::unordered_map<relax::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
+  for (auto& kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = Constant(kv.second);
+  }
+  Expr bound_expr = Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  ICHECK(ret.defined()) << "The returned value is expected to be a Relax Function."
+                        << "\n";
+  return ret;
+}
+
+/*!
+ * \brief Bind params to a specific function in a module
+ * \param m The module
+ * \param func_name The name of the specific function
+ * \param param The param dict
+ * \return The module after binding params.
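+ *
+ * Rough usage sketch from the Python side (the parameter name "w" is
+ * illustrative):
+ *   mod = relax.transform.BindParams("main", {"w": tvm.nd.array(arr)})(mod)
+ * after which the argument `w` of `main` is replaced by a relax.Constant.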
+ */
+IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
+  IRModuleNode* new_module = m.CopyOnWrite();
+  Map<GlobalVar, BaseFunc> functions = m->functions;
+  for (const auto& func_pr : functions) {
+    if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
+      Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      if (gsymbol.defined() && gsymbol.value() == func_name) {
+        Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
+        new_module->Update(func_pr.first, f_after_bind);
+      }
+    }
+  }
+  return GetRef<IRModule>(new_module);
+}
+
+namespace transform {
+
+Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
+  return CreateModulePass(pass_func, 0, "BindParams", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
new file mode 100644
index 0000000000..f72b607f41
--- /dev/null
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/call_tir_rewrite.cc
+ * \brief Perform explicit tensor allocation for call_tir.
+ */
+#include
+#include
+#include
+#include
+#include
+
+#include "../../relay/transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// CallTIRMutator
+// Perform explicit tensor allocation for call_tir.
+// Example:
+// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
+// -->
+// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32")
+// rx.call_packed(func, x, gv0)
+
+class CallTIRMutator : public ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const CallNode* call) override {
+    // post-order mutation
+    Expr expr = VisitExprPostOrder_(call);
+    call = expr.as<CallNode>();
+
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
+
+    if (call->op == call_tir_op) {
+      Array<Expr> outs;
+      if (call->shape_) {
+        if (call->shape_.value()->IsInstance<ShapeExprNode>()) {
+          // single output case
+          ShapeExpr output_shape = Downcast<ShapeExpr>(call->shape_.value());
+          auto alloc_tensor_attr = make_object<AllocTensorAttrs>();
+
+          if (call->checked_type_.defined()) {
+            auto output_type = Downcast<DynTensorType>(call->checked_type_);
+            alloc_tensor_attr->dtype = output_type->dtype;
+            alloc_tensor_attr->runtime_device_index = 0;
+            outs.push_back(builder_->Emit(
+                Call(alloc_tensor_op, {output_shape}, Attrs(alloc_tensor_attr)), "alloc"));
+          } else {
+            LOG(FATAL) << "ValueError: the checked_type_ of call_tir has not been populated.";
+          }
+        } else {
+          // multiple output case
+          ICHECK(call->shape_.value()->IsInstance<TupleNode>())
+              << "call_tir expects ShapeExpr or Tuple as its shape, but got " << call->shape_;
+          ICHECK(call->checked_type_->IsInstance<TupleTypeNode>())
+              << "call_tir expects DynTensorType or TupleType as its checked type, but got "
+              << call->checked_type_;
+          Tuple output_shapes = Downcast<Tuple>(call->shape_.value());
+          TupleType output_types = Downcast<TupleType>(call->checked_type_);
+          ICHECK(output_shapes->fields.size() == output_types->fields.size())
+              << "The output of call_tir should have the same amount of fields in its shape_ and "
+                 "checked_type_";
+          for (size_t i = 0; i < output_shapes->fields.size(); ++i) {
+            ICHECK(output_shapes->fields[i]->IsInstance<ShapeExprNode>())
+                << "call_tir expects Tuple of ShapeExprs, but got " << output_shapes->fields[i]
+                << " as an element of tuple";
+            ICHECK(output_types->fields[i]->IsInstance<DynTensorTypeNode>())
+                << "call_tir expects TupleType of DynTensorType, but got "
+                << output_types->fields[i] << " as an element of TupleType";
+            auto output_type = Downcast<DynTensorType>(output_types->fields[i]);
+            auto alloc_tensor_attr = make_object<AllocTensorAttrs>();
+            alloc_tensor_attr->dtype = output_type->dtype;
+            alloc_tensor_attr->runtime_device_index = 0;
+            outs.push_back(builder_->Emit(
+                Call(alloc_tensor_op, {Downcast<ShapeExpr>(output_shapes->fields[i])},
+                     Attrs(alloc_tensor_attr)),
+                "alloc"));
+          }
+        }
+      } else {
+        LOG(FATAL) << "ValueError: the shape of call_tir has not been populated.";
+      }
+
+      Array<Expr> args;
+      if (call->args[1].as<TupleNode>()) {
+        args = Downcast<Tuple>(call->args[1])->fields;
+        args.insert(args.end(), outs.begin(), outs.end());
+
+        if (call->args.size() == 3) {
+          builder_->Emit(Call(call->args[0], args), "_");
+        } else {
+          // unpack semantics
+          args.push_back(call->args[3]);
+          builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_");
+        }
+      } else {
+        args = outs;
+        args.insert(args.begin(), call->args[1]);
+        builder_->Emit(Call(call->args[0], args), "_");
+      }
+
+      if (outs.size() == 1) {
+        return outs[0];
+      }
+      return Tuple(outs);
+    }
+
+    return GetRef<Expr>(call);
+  }
+};
+
+Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass CallTIRRewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
+  return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
new file mode 100644
index 0000000000..e938f8938c
--- /dev/null
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/canonicalize_bindings.cc
+ * \brief Pass for simplifying modules by folding var bindings and match shape nodes.
+ *        May include other forms of simplification in the future.
+ *        Ideally should be used before constant folding and eliminating unused bindings.
+ */
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+class BindingCanonicalizer : public ExprMutator {
+ public:
+  BindingCanonicalizer() {}
+
+  Expr VisitExpr_(const VarNode* op) override {
+    // remap first
+    Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
+    if (!CanCanonicalizeVar(v)) {
+      return Downcast<Expr>(v);
+    }
+    // visit again in case we need to do a substitution in the value
+    return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) override {
+    Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
+    if (!CanCanonicalizeVar(v)) {
+      return Downcast<Expr>(v);
+    }
+    return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // Unlike default visitor, we do not permit the checked type to change
+    // if the new value's checked type is different (this preserves user annotations)
+    Expr new_value = this->VisitExpr(binding->value);
+    Var new_var = this->VisitVarDef(binding->var);
+
+    auto emit = [this](VarBinding b) {
+      if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as<DataflowVarNode>()) {
+        this->builder_->EmitOutput(b);
+      } else {
+        this->builder_->Emit(b);
+      }
+    };
+
+    if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
+      emit(GetRef<VarBinding>(binding));
+      return;
+    }
+
+    emit(VarBinding(new_var, new_value));
+  }
+
+  void VisitBinding_(const MatchShapeNode* binding) override {
+    // If we have a trivial shape check (the shape_ of LHS and RHS is the same),
+    // we can canonicalize to a var binding
+    Expr new_value = this->VisitExpr(binding->value);
+
+    Var new_var;
+    // since we do not permit the checked_type to change and don't make any changes
+    // to the shape pattern, there is no reason to do any more checking like in the
+    // original mutator
+    if (binding->var.defined()) {
+      new_var = this->VisitVarDef(binding->var);
+    }
+
+    // if the LHS and RHS have the same shape_, we canonicalize
to a var binding instead + if (new_var.defined() && new_value->shape_.defined() && + builder_->CanProveShapeEqual(Downcast(new_var->shape_), + Downcast(new_value->shape_))) { + builder_->Emit(VarBinding(new_var, new_value)); + return; + } + + // reemit old binding if nothing changes + if (new_value.same_as(binding->value)) { + if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) { + builder_->EmitMatchShape(GetRef(binding)); + return; + } + } + + builder_->EmitMatchShape(MatchShape(new_value, binding->pattern, new_var)); + } + + private: + bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2, + std::function check_eq) { + // annotations differ if one is present but not the other + // or they're both present and they differ + bool both_present = obj1.defined() && obj2.defined(); + bool neither_present = !obj1.defined() && !obj2.defined(); + return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2)); + } + + bool CanCanonicalizeVar(Var v) { + Optional value = LookupBinding(v); + // can replace only if the value is also a var + if (!value || !value.as()) { + return false; + } + Var parent_var = Downcast(value); + + // Cases when we conservatively do not unify: + // 1. checked_type_ or shape_ of the child differs from that of the parent + // In this case, we could be overriding user annotations. + // 2. If the child is a Var and the parent is a DataflowVar. + // That could result in a DataflowVar leaving the current DataflowBlock. + bool annotations_differ = + AnnotationsDiffer(v->shape_, parent_var->shape_, + [&](const ObjectRef& shape1, const ObjectRef& shape2) { + return builder_->CanProveShapeEqual(Downcast(shape1), + Downcast(shape2)); + }) || + AnnotationsDiffer(v->checked_type_, parent_var->checked_type_, + [&](const ObjectRef& type1, const ObjectRef& type2) { + return tvm::StructuralEqual()(type1, type2); + }); + bool var_to_dataflow = (!v.as() && parent_var.as()); + return !annotations_differ && !var_to_dataflow; + } +}; + +Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); } + +namespace transform { + +Pass CanonicalizeBindings() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeBindings(f)); + }; + return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fail_test_rewrite.cc b/src/relax/transform/fail_test_rewrite.cc new file mode 100644 index 0000000000..003257408a --- /dev/null +++ b/src/relax/transform/fail_test_rewrite.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/fail_test_rewrite.cc + * \brief Incorrectly transform the dataflow structure as fail testcases. + */ +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Rewrite/Remove global var or symbolic var in the dataflow block.*/ +class FailTestRewriter : public ExprMutator { + using ExprMutator::VisitExpr_; + + // Rewrite/Remove specific global var + Var VisitVarDef_(const VarNode* var) override { + if (var->name_hint() == "gv_rewrite") { + Var new_var = Var("gv_rewrite", {}, {}, {}); + return std::move(new_var); + } else if (var->name_hint() == "gv_remove") { + Var new_var = Var("new_gv", {}, {}, {}); + return std::move(new_var); + } + return GetRef(var); + } + + // Rewrite/Remove specific symbolic var + Expr VisitExpr_(const ShapeExprNode* op) override { + if (op->values.size() == 2) { + tir::Var arg0 = Downcast(op->values[0]); + tir::Var new_arg0 = tir::Var(arg0->name_hint); + ShapeExpr new_op = ShapeExpr({new_arg0, op->values[1]}); + return std::move(new_op); + } else if (op->values.size() == 3) { + ShapeExpr new_op = ShapeExpr({op->values[0], op->values[1]}); + return std::move(new_op); + } + return GetRef(op); + } +}; + +BindingBlock FailTestRewrite(const BindingBlock& block) { + return FailTestRewriter().VisitBindingBlock(block); +} + +namespace transform { + +Pass FailTestRewrite() { + runtime::TypedPackedFunc pass_func = + [=](DataflowBlock block, IRModule m, PassContext pc) { + return Downcast(FailTestRewrite(block)); + }; + return CreateDataflowBlockPass(pass_func, 2, "FailTestRewrite", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FailTestRewrite").set_body_typed(FailTestRewrite); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fma_rewrite.cc b/src/relax/transform/fma_rewrite.cc new file mode 100644 index 0000000000..0ca9f7164e --- /dev/null +++ b/src/relax/transform/fma_rewrite.cc @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/fma_rewrite.cc + * \brief Perform fused multiply-add rewriting. + */ +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief Rewrites the relax.add call to a relax.ewise_fma call when detecting the multiply-add + * pattern. + * + * Example: + * x0 = mul(a, b) + * z0 = add(x0, c) + * --> + * z0 = ewise_fma(a, b, c) + * + * Example 2: + * Question: do we want to support this? 
+ * x0 = mul(a, add(k, b)) + * z0 = add(x0, c) + * --> + * lv0 = add(k, b) + * z0 = ewise_fma(a, lv0, c) + */ +class EwiseFMARewriter : public ExprMutator { + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const CallNode* call) override { + Expr expr = VisitExprPostOrder_(call); + call = expr.as(); + + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); + + if (call->op == add_op) { + // NOTE: assumes df block is completely SSA + // FIXME(@altanh, @yuchen): this will crash if args[0] isn't a Var + Optional value = LookupBinding(Downcast(call->args[0])); + const CallNode* mul = value.as(); + if (mul && mul->op == multiply_op) { + Call fma_call = Call(ewise_fma_op, {mul->args[0], mul->args[1], call->args[1]}, {}, {}); + return std::move(fma_call); + } + } + + return GetRef(call); + } +}; + +BindingBlock RewriteFMA(const BindingBlock& block) { + return EwiseFMARewriter().VisitBindingBlock(block); +} + +/*! \brief Performs multiply add fusion. The difference of EwiseFMARewriter and this + * EwiseFuseFMAMutator class is that this mutator generates a sub function(subgraph) whose body is a + * CallNode that calls to the relax.ewise_fma op, and rewrites the relax.add call in the main + * function to calling to the subgraph. + * + * Example: + * Before-transformation IRModule: + * def main(): + * x0 = mul(a, b) + * z0 = add(x0, c) + * --> + * After-transformation IRModule: + * def ewise_fused(x, y, z): + * return relax.ewise_fma(x, y, z) + * + * def main(): + * z0 = ewise_fused(a, b, c) + */ +class EwiseFuseFMAMutator : public ExprMutator { + public: + explicit EwiseFuseFMAMutator(IRModule mod) { mod_ = mod; } + + IRModule Transform() { + for (auto& p : mod_->functions) { + Expr func = p.second; + if (func->IsInstance()) { + func = this->VisitExpr(func); + } + builder_->AddFunction(Downcast(func), p.first->name_hint); + } + return builder_->GetContextIRModule(); + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call) override { + Expr expr = VisitExprPostOrder_(call); + call = expr.as(); + + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); + + if (call->op == add_op) { + Optional value = LookupBinding(Downcast(call->args[0])); + const CallNode* mul = value.as(); + if (mul && mul->op == multiply_op) { + // construct a subgraph + Var x = Var("x", Downcast(mul->args[0]->shape_), mul->args[0]->checked_type_); + Var y = Var("y", Downcast(mul->args[1]->shape_), mul->args[1]->checked_type_); + Var z = Var("z", Downcast(call->args[1]->shape_), call->args[1]->checked_type_); + Expr body = Call(ewise_fma_op, {x, y, z}); + + String func_name = "ewise_fma_fused"; + Function func = Function({x, y, z}, body, call->args[1]->checked_type_, RuntimeDepShape()); + Expr ewise_fma_fused = WithAttr(std::move(func), "global_symbol", func_name); + Expr normalized = builder_->Normalize(ewise_fma_fused); + GlobalVar global_var1 = + builder_->AddFunction(Downcast(normalized), "ewise_fma_fused"); + + // construct a call to the subgraph + Call fma_call = Call(global_var1, {mul->args[0], mul->args[1], call->args[1]}, {}, {}); + return std::move(fma_call); + } + } + + return GetRef(call); + } + + private: + IRModule mod_; +}; + +namespace transform { + +Pass RewriteFMA() { + runtime::TypedPackedFunc pass_func = + [=](DataflowBlock block, IRModule m, PassContext pc) { 
+ return Downcast(RewriteFMA(block)); + }; + return CreateDataflowBlockPass(pass_func, 2, "RewriteFMA", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RewriteFMA").set_body_typed(RewriteFMA); + +Pass FuseFMA() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return EwiseFuseFMAMutator(mod).Transform(); }; + return CreateModulePass(pass_func, 2, "FuseFMA", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseFMA").set_body_typed(FuseFMA); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc new file mode 100644 index 0000000000..017cde9674 --- /dev/null +++ b/src/relax/transform/fold_constant.cc @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +class ConstantFolder : public ExprMutator { + public: + explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {} + + private: + /*! + * \brief Pattern match expr to a constant shape and get runtime shape tuple from it. + * \return The runtime shape tuple, or nullopt if it is not a constant shape. + */ + static Optional MatchConstShape(const Expr& expr) { + auto* shape = expr.as(); + if (!shape) return NullOpt; + + std::vector shape_values; + for (const auto v : shape->values) { + auto* ptr = v.as(); + if (!ptr) return NullOpt; + shape_values.push_back(ptr->value); + } + return runtime::ShapeTuple(shape_values.begin(), shape_values.end()); + } + + /*! + * \brief Pattern match op to constant array arguments. + * \return The constant array arguments, or nullopt if match fails. + */ + static Optional> MatchConstArrayArgs(const Array& args) { + Array res; + for (auto arg : args) { + auto* ptr = arg.as(); + if (!ptr) return NullOpt; + res.push_back(ptr->data); + } + return res; + } + + /*! + * \brief Pattern match op to a TIR function and look it up. + * \return The TIR function, or nullopt if pattern match fails. + */ + Optional MatchPrimFunc(const Expr& op) { + if (auto* ptr = op.as()) { + // NOTE: as check works for nullptr(returns null) + Optional base_func = ctx_module_->functions.Get(GetRef(ptr)); + if (auto* pfunc = base_func.as()) { + return GetRef(pfunc); + } + } + return NullOpt; + } + + /*! + * \brief Get a cached build version of func + * \return The cached func, nullopt if func cannot be built. 
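+   *
+   * Sketch of the behaviour: the PrimFunc is lowered and built once for a
+   * local "llvm" CPU target, and the resulting PackedFunc (or NullOpt on
+   * build failure) is memoized keyed by structural equality, so repeated
+   * folds of structurally identical functions reuse a single build.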
+ */ + Optional GetCachedBuild(tir::PrimFunc func) { + // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once + // would be helpful for future cases where PrimFunc recursively call into each other + Target eval_cpu_target{"llvm"}; + + auto it = func_build_cache_.find(func); + if (it != func_build_cache_.end()) { + return it->second; + } + Optional build_func = NullOpt; + + try { + // Not all the primfunc can be directly built via llvm, for example, if a function is + // already scheduled to only work on GPU, we will need to skip this in the const folder for + // now + // TODO(Hongyi): further check and narrow the scope of foldable function + runtime::Module rt_module = + build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target); + build_func = rt_module.GetFunction("tir_function"); + } catch (const tvm::Error& err) { + // build failure may happen in which case we skip + DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what(); + } + func_build_cache_[func] = build_func; + return build_func; + } + + // Try constant evaluate the function call + // if failed return NullOpt + Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, + runtime::ShapeTuple shape, DataType ret_type) { + // obtain function from the cache. + Optional func = GetCachedBuild(tir_func); + if (!func) return NullOpt; + + // here the vector size has an additional + 1 because we need to put ret_tensor at the end + std::vector values(arr_args.size() + 1); + std::vector type_codes(arr_args.size() + 1); + + DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0}; + runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type, cpu_dev); + + // avoid set rvalue ref which get de-allocated later, store args in a vector + // where temp_args[i] are lvalue ref that is stable + std::vector temp_args(arr_args.begin(), arr_args.end()); + + size_t arg_offset = 0; + for (; arg_offset < arr_args.size(); ++arg_offset) { + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset, temp_args[arg_offset]); + } + // set return value + runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset++, ret_tensor); + + TVMRetValue ret; + // invoke + func.value().CallPacked(TVMArgs(values.data(), type_codes.data(), values.size()), &ret); + return Constant(ret_tensor); + } + + Expr VisitCallTIR(Call call) { + // call_tir needs to have at least three arguments + ICHECK_GE(call->args.size(), 3); + Optional func = MatchPrimFunc(call->args[0]); + ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; + Optional> arr_args = + MatchConstArrayArgs(call->args[1].as()->fields); + Optional shape = MatchConstShape(call->args[2]); + bool output_not_tuple = call->type_args.size() == 1; + // Pattern 0: call constant function, const argument with const shape. 
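+    // For example (names illustrative):
+    //   lv0 = relax.call_tir(@tir_add, (c0, c1), (2, 2), dtype="float32")
+    // with constant args c0/c1 is evaluated here at compile time and replaced
+    // by a relax.Constant holding the computed NDArray.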
+    if (func && arr_args && shape && output_not_tuple) {
+      DynTensorType ret_type = Downcast<DynTensorType>(call->checked_type());
+      // value_or returns the folded constant when evaluation succeeds,
+      // otherwise it returns the original call.
+      return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_type->dtype)
+          .value_or(call);
+    }
+    // TODO(hongyi): support const-fold tuple outputs
+    return std::move(call);
+  }
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* call) final {
+    // post-order mutation
+    Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+
+    if (call->op.same_as(call_tir_op)) {
+      return VisitCallTIR(post_call);
+    }
+    return std::move(post_call);
+  }
+
+  Expr VisitExpr_(const DataflowVarNode* op) final {
+    Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
+    // the `as` check verifies that opt is not null and is an instance of Constant
+    if (opt.as<ConstantNode>()) {
+      return opt.value();
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const VarNode* op) final {
+    Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
+    // the `as` check verifies that opt is not null and is an instance of Constant
+    if (opt.as<ConstantNode>()) {
+      return opt.value();
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  // the context module to lookup functions
+  IRModule ctx_module_;
+  // cache for function build, via structural equality
+  std::unordered_map<tir::PrimFunc, Optional<runtime::PackedFunc>, StructuralHash, StructuralEqual>
+      func_build_cache_;
+};
+
+namespace transform {
+
+Pass FoldConstant() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        ConstantFolder folder(m);
+        return Downcast<Function>(folder(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
new file mode 100644
index 0000000000..6a813d2151
--- /dev/null
+++ b/src/relax/transform/fuse_ops.cc
@@ -0,0 +1,797 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/fuse_ops.cc
+ * \brief This file contains a pass which groups bindings in a dataflow block of Relax
+ * functions and generates a new grouped Relax function for each group, according to the fusion
+ * algorithm described below. By grouping bindings into new Relax functions, we replace the
+ * bindings in the function being manipulated with calls to the new grouped functions.
+ *
+ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.
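+ *
+ * Illustrative end-to-end effect (names are made up): a dataflow block with
+ *   lv0 = relax.call_tir(@exp, (x,), ...)
+ *   gv = relax.call_tir(@add, (lv0, y), ...)
+ * whose PrimFuncs are fusible becomes a single call
+ *   gv = fused_exp_add(x, y)
+ * where fused_exp_add is a new Relax function marked kPrimitive that
+ * contains the two call_tir bindings.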
+ */
+
+#include
+#include
+#include
+
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+
+namespace tvm {
+namespace relax {
+
+/*
+  Note on Fusing algorithm:
+
+  The main challenge of a general fusor is to handle possible diamond-shaped branches;
+  in the following graph, conv2d can be fused into the elemwise add.
+
+            conv2d
+            /  |  \
+           /   |   \
+         op    op   op
+          \    |    /
+           \   |   /
+          elemwise add
+               |
+
+  However, at the point of conv2d we do not necessarily know that all the future paths
+  will merge at the elemwise add. The fusion algorithm applies post-dominator analysis.
+
+  The immediate post-dominator of a node is defined as the closest node through which all future
+  paths pass. In the above case, the elemwise add is the post-dominator of conv2d. The general
+  algorithm is as follows:
+
+  - Construct a DAG of dataflow graph for dominator analysis
+  - Construct a post-dominator tree which gives immediate post dominator of each node.
+  - Run fusion algorithm with the given post-dominator information.
+
+  Note that, because we run analysis on a DAG, we use a single pass post-dominator
+  tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
+
+  The fusion algorithm traverses from each node and checks if it can be fused to its
+  immediate post dominator. It has to check the following things:
+
+  - CheckPath: check that every path between a node and its immediate post-dominator
+    satisfies the fuse condition.
+    - Note that these intermediate nodes may already be fused with other nodes; the algorithm
+      will still run correctly.
+  - CommitFuse: mark all the nodes between source and post-dominator as the same group.
+  - We use a Union-Find data structure to manage the groups.
+*/
+
+using relay::GraphPartitioner;
+using relay::IndexedForwardGraph;
+using relay::OpPatternKind;
+using support::LinkNode;
+
+constexpr uint32_t kMaxFusedOps = 256;
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer);
+
+class GraphCreator : public ExprVisitor {
+ public:
+  /*!
+   * \brief Create an IndexedForwardGraph according to the input module. The graph will be used
+   * for graph partition and operator fusion.
+   * \param mod The module from which the graph is created
+   * \param arena The allocator of all the internal node objects
+   * \return The created IndexedForwardGraph
+   */
+  static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
+    // Since cross-function call is not supported yet, FuseOps only serves the entry function,
+    // whose name is "main".
+    auto relax_func = Downcast<Function>(mod->Lookup("main"));
+    GraphCreator creator(mod, arena);
+    creator(relax_func);
+
+    // The algorithm of the graph creator ensures that each created node will be added to the
+    // post-dfs order and will have its op pattern set. Thus we check whether all these containers
+    // have the same size.
+    size_t n_nodes = creator.graph_.node_map.size();
+    ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size());
+    ICHECK_EQ(n_nodes, creator.initialized_nodes_.size());
+
+    return creator.graph_;
+  }
+
+ private:
+  explicit GraphCreator(IRModule mod, support::Arena* arena)
+      : mod_(std::move(mod)), arena_(arena) {}
+
+  void VisitExpr_(const FunctionNode* func) final {
+    for (const Var& param : func->params) {
+      IndexedForwardGraph::Node* param_node = CreateNode(param.get());
+      // The parameter is passed in from the outside, and thus it is marked as an external
+      // reference, and its pattern is `kOpaque`.
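+      // Being an external reference means the partitioner will never fuse
+      // the parameter into the interior of a group; it can only serve as a
+      // group input.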
+ MarkAsExternRef(param_node); + SetNodePattern(param_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(param_node, param.get()); + } + ExprVisitor::VisitExpr_(func); + } + + void VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + VisitBindingBlock_(df_block); + } + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) + } + + // TODO(tvm-team): how to deal with MatchShape binding here + + void VisitBinding_(const VarBindingNode* binding) final { + IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); + + // If the variable is not a dataflow variable, it must be the output variable of this dataflow + // block + if (!binding->var->IsInstance()) { + this->MarkAsExternRef(node); + } + if (const auto* call = binding->value.as()) { + // Case 1. The expression is a CallNode + VisitCall(call, node); + } else if (const auto* tuple_get_item = binding->value.as()) { + // Case 2. The expression is a TupleGetItemNode + VisitTupleGetItem(tuple_get_item, node); + } else { + VisitUnsupportedNode(binding->value, node); + // Case 3. The type of the expression is not fusion-supported. + // In this case, we skip adding edges, adding an empty node into graph. + } + AddToPostDFSOrder(node, binding->var.get()); + } + + /********** Non-Leaf Expression Nodes **********/ + + void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + OpPatternKind pattern = OpPatternKind::kOpaque; + Array args = call->args; + + // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the + // function attribute and visit the arguments one by one. + // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we + // recurse into the call expression. + const auto* op = call->op.as(); + if (op == call_tir_op_.get()) { + const GlobalVar& global_var = Downcast(call->args[0]); + tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); + + // Override args for call_tir + args = Downcast(call->args[1])->fields; + + // TODO(tvm-team): handle the shape argument (args[3]) + Optional opt_pattern = func->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; + } + } + // The pattern of the current binding variable node is set to the pattern of this operator. 
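+    // For example (illustrative): a binding whose value is
+    // relax.call_tir(@add, (x, y), ...) with the PrimFunc annotated as
+    // op_pattern = kElemWise makes this node elementwise, while a call to an
+    // arbitrary packed function leaves the node kOpaque.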
+ SetNodePattern(binding_var_node, pattern); + // Visit all call args + for (const Expr& arg : args) { + ICHECK(IsLeaf(arg)); + VisitLeaf(arg, binding_var_node, pattern); + } + } + + void VisitTupleGetItem(const TupleGetItemNode* tuple_item, + IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + + SetNodePattern(binding_var_node, OpPatternKind::kInjective); + VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective); + } + + void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) { + ICHECK_NOTNULL(binding_var_node); + SetNodePattern(binding_var_node, OpPatternKind::kOpaque); + + auto visit_leaves = [this, &binding_var_node](const Expr& e) { + if (e->IsInstance() || e->IsInstance()) { + VisitLeaf(e, binding_var_node, OpPatternKind::kOpaque); + } + }; + PostOrderVisit(expr, visit_leaves); + } + + /********** Leaf Expression Nodes **********/ + + void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, + const OpPatternKind& pattern) { + ICHECK_NOTNULL(binding_var_node); + + // Recursive visit if it's Tuple + if (const auto* tuple = leaf_expr.as()) { + for (const Expr& expr : tuple->fields) { + VisitLeaf(expr, binding_var_node, pattern); + } + return; + } + + auto it = graph_.node_map.find(leaf_expr.get()); + IndexedForwardGraph::Node* leaf_node = nullptr; + if (it != graph_.node_map.end()) { + leaf_node = it->second; + } else if (leaf_expr->IsInstance()) { + leaf_node = CreateNode(leaf_expr.get()); + // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. + SetNodePattern(leaf_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(leaf_node, leaf_expr.get()); + } else { + LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr + << " used before definition."; + } + AddEdge(leaf_node, binding_var_node, pattern); + } + + /********** Helper Functions **********/ + + /*! + * \brief Check whether the expression is a leaf expression + * \param expr The expression to be checked + * \return Whether the expression is a leaf expression + * \note In order to avoid too much refactor, this method is a simple copy-paste of the is-leaf + * check in "block_builder.cc". And it should be refactored in the future. + * \sa src/relax/ir/block_builder.cc + */ + static bool IsLeaf(const Expr& expr) { + // NOTE: Tuples are treated as leaf nodes for ergonomics + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as() || + expr.as(); + } + + /*! + * \brief Create a graph node corresponding to the input key + * \param key The object which is used to create the graph node + * \return The created graph node + * \note The node corresponding to each key is supposed to be created for only once + */ + IndexedForwardGraph::Node* CreateNode(const Object* key) { + ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) + << "The node corresponding to the input key is not supposed to be created before"; + auto* node = arena_->make(); + graph_.node_map[key] = node; + return node; + } + + /*! 
+   * \brief Append the input node to the post-dfs order of the graph
+   * \param node The node to be appended
+   * \param key The key corresponding to the node
+   * \note Each node is supposed to be appended to the post-dfs order only once
+   */
+  void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) {
+    auto it = graph_.node_map.find(key);
+    ICHECK(it != graph_.node_map.end() && it->second == node)
+        << "The node must have been created before adding to the post-dfs order";
+
+    // We only set the reference of the node when adding it to the post-dfs order. Thus, if the
+    // reference of a node is already set, it must have been appended to the post-dfs order.
+    ICHECK(node->ref == nullptr)
+        << "The node is not supposed to be added into the post-dfs order before";
+
+    node->ref = key;
+    node->index = graph_.post_dfs_order.size();
+    graph_.post_dfs_order.push_back(node);
+  }
+
+  /*!
+   * \brief Add an edge from the input start to the input end in the graph, with specific pattern
+   * \param start The start of the edge
+   * \param end The end of the edge
+   * \param pattern The pattern of this edge
+   */
+  void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end,
+               OpPatternKind pattern) {
+    auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
+    link->value.node = end;
+    link->value.pattern = pattern;
+    start->outputs.Push(link);
+  }
+
+  /*!
+   * \brief Mark a given node as "external reference", which means the node cannot be fused as an
+   * intermediate node
+   * \param node The graph node to be marked
+   */
+  void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref = true; }
+
+  /*!
+   * \brief Set the pattern of the input node
+   * \param node The graph node to be set
+   * \param pattern The pattern of the node
+   */
+  void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) {
+    ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end())
+        << "The input node is supposed to have its pattern set only once";
+    initialized_nodes_.insert(node);
+    node->pattern = pattern;
+  }
+
+ private:
+  /*! \brief The IRModule from which the indexed forward graph is created */
+  IRModule mod_;
+  /*! \brief The allocator of all the internal node objects */
+  support::Arena* arena_;
+  /*! \brief The created indexed forward graph */
+  IndexedForwardGraph graph_;
+  /*! \brief The graph nodes whose patterns are set */
+  std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+};
+
+/*!
+ * \brief The ExprMutator used to create a new grouped function
+ * \details The workflow of this ExprMutator is:
+ *  - The bindings in the function will be added by OperatorFusor via `AppendBinding(...)`.
+ *  - When adding a new binding through `AppendBinding(...)`, we check whether the variables and
+ *  constants used by the binding are defined by some previously added binding. And for the
+ *  undefined variables and constants, we add them to the argument list and create new variables
+ *  as the corresponding parameters.
+ *  - When `CreateFunction()` is called, we go through each binding and update the binding with the
+ *  new parameters. After that we wrap all bindings with a DataflowBlock and a Function.
+ */
+class FunctionCreator : public ExprMutator {
+ public:
+  /*!
+   * \brief Append a new binding to this function and possibly create new parameters for the
+   * function accordingly
+   * \param binding The binding to be appended
+   * \note Allowed bindings are:
+   *  - VarBinding with value being a call node calling `relax.call_tir`.
+   *  - VarBinding with value being a tuple-get-item node.
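+   *
+   * For instance (illustrative), appending lv1 = relax.call_tir(@mul, (lv0, w), ...)
+   * when lv0 is defined inside the group but w is not, records w as a
+   * call-side argument, creates a fresh parameter for it, and extends the
+   * name hint to "fused_..._mul".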
+   * // TODO(tvm-team): handle match shape
+   */
+  void AppendBinding(const Binding& binding) {
+    ICHECK(!function_.defined())
+        << "The `function_` is supposed to be uncreated when adding bindings";
+
+    if (const auto* var_binding = binding.as<VarBindingNode>()) {
+      if (const auto* call = var_binding->value.as<CallNode>()) {
+        ICHECK(call->op == Op::Get("relax.call_tir"));
+        // Update the name of the function.
+        name_hint_ = name_hint_ + "_" + Downcast<GlobalVar>(call->args[0])->name_hint;
+
+        const Tuple& args = Downcast<Tuple>(call->args[1]);
+        for (const Expr& arg : args->fields) {
+          CheckDefAndUpdateParam(arg);
+        }
+        // TODO(tvm-team): handle shape expr
+      } else {
+        const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
+        ICHECK(tuple_item != nullptr);
+        CheckDefAndUpdateParam(tuple_item->tuple);
+      }
+
+      // Mark the binding variable as defined.
+      defined_vars_.insert(var_binding->var.get());
+      // Mark the var as an output if the binding is not a dataflow variable.
+      if (!var_binding->var->IsInstance<DataflowVarNode>()) {
+        AppendOutput(var_binding->var);
+      }
+    } else {
+      // TODO(tvm-team): handle match_shape
+    }
+    bindings_.push_back(binding);
+  }
+
+  /*! \brief Set a var defined in the group as output. */
+  void AppendOutput(const Var& var) {
+    ICHECK(defined_vars_.count(var.get()));
+    output_vars_.insert(var.get());
+  }
+
+  /*!
+   * \brief Create the grouped function according to the collected bindings and parameters
+   * \note The created function won't be returned immediately. It's stored in the `function_`
+   * field.
+   */
+  void CreateFunction() {
+    // Step 1. Start constructing a new dataflow block.
+    builder_->BeginDataflowBlock();
+
+    // Step 2. Visit each binding and collect outputs one by one.
+    Array<Expr> outputs;
+    for (const Binding& binding : bindings_) {
+      const VarNode* var = nullptr;
+      if (const auto* var_binding = binding.as<VarBindingNode>()) {
+        var = var_binding->var.get();
+      } else if (const auto* match_shape = binding.as<MatchShapeNode>()) {
+        var = match_shape->var.get();
+      } else {
+        ICHECK(false);
+      }
+      if (output_vars_.count(var)) {
+        // Case 1. It is an output binding
+        // We only allow VarBinding as output.
+        const auto* var_binding = binding.as<VarBindingNode>();
+        ICHECK_NOTNULL(var_binding);
+        Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value));
+        var_remap_[var_binding->var->vid] = output_var;
+        outputs.push_back(output_var);
+      } else {
+        // Case 2. It is an internal binding, add it to the binding list.
+        VisitBinding(binding);
+      }
+    }
+
+    // Step 3. Finish constructing the new block.
+    BindingBlock new_block = builder_->EndBlock();
+    ICHECK(!outputs.empty()) << "At least one output is required.";
+    Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+    body = builder_->Normalize(body);
+    body = builder_->Normalize(SeqExpr({new_block}, body));
+    Map<String, ObjectRef> attrs;
+    attrs.Set(tvm::attr::kGlobalSymbol, name_hint_);
+    attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
+    function_ = Function(/*params=*/params_,  //
+                         /*body=*/body,  //
+                         /*ret_type=*/body->checked_type_,
+                         /*ret_shape=*/RuntimeDepShape(),
+                         /*attrs=*/DictAttrs(attrs));
+  }
+
+  /*! \brief The original bindings of the function */
+  Array<Binding> bindings_;
+  /*! \brief The parameters of the function */
+  Array<Var> params_;
+  /*! \brief The arguments to call the function on the caller side */
+  Array<Expr> arguments_;
+  /*! \brief The name for the fused function */
+  String name_hint_ = "fused";
+  /*! \brief The constructed Relax function */
+  Function function_{nullptr};
+
+ private:
+  /*!
+   * \brief Check whether the input expression is defined within this function.
+   * parameter for the expression.
+   * \param expr The expression to be checked
+   */
+  void CheckDefAndUpdateParam(const Expr& expr) {
+    // If the expression has already served as an argument, no need to create another one for it.
+    auto it = std::find(arguments_.begin(), arguments_.end(), expr);
+    if (it != arguments_.end()) {
+      return;
+    }
+
+    // If the expression is not a variable or is an undefined variable, it should be lifted to a
+    // parameter of the relax function.
+    const auto* var = expr.as<VarNode>();
+    if (var == nullptr || defined_vars_.count(var) == 0) {
+      String name{nullptr};
+      if (var != nullptr) {
+        name = var->name_hint();
+      } else {
+        name = String("param_" + std::to_string(n_param_for_const_++));
+      }
+
+      Var param(std::move(name),               //
+                /*shape_annotation=*/NullOpt,  //
+                /*type_annotation=*/expr->checked_type_);
+      param->shape_ = expr->shape_;
+      arguments_.push_back(expr);
+      params_.push_back(param);
+    }
+  }
+
+  Expr VisitExpr(const Expr& expr) final {
+    // If the expression serves as an argument, return its corresponding parameter.
+    auto it = std::find(arguments_.begin(), arguments_.end(), expr);
+    if (it != arguments_.end()) {
+      return params_[it - arguments_.begin()];
+    }
+    // Otherwise, recurse into this expression.
+    return ExprMutator::VisitExpr(expr);
+  }
+
+ private:
+  /*! \brief The variables defined in this function */
+  std::unordered_set<const VarNode*> defined_vars_;
+  /*! \brief The number of parameters reserved for constants */
+  int n_param_for_const_ = 0;
+  /*! \brief The output vars */
+  std::unordered_set<const VarNode*> output_vars_;
+};
+
+/*!
+ * \brief The ExprMutator used to fuse the operators in Relax functions
+ * \details Given the partition results on the indexed-forward graph, for each group whose size is
+ * larger than one, we create a new grouped function for it, containing all bindings in that group.
+ * We then substitute the bindings in a group with a single function call to the newly created
+ * grouped function. The workflow of this ExprMutator is: for each dataflow block,
+ *  - we go through the bindings one by one. For each binding, if it is in a group whose size is
+ *  larger than one, we add the binding to the function of the group it is in and update the
+ *  parameters and arguments of that function;
+ *  - then we finalize all the grouped functions by updating their bindings using BlockBuilder;
+ *  - lastly, we go through the bindings again and substitute the bindings in a group with a single
+ *  call to the corresponding grouped function.
+ *
+ * After transforming a Relax function, we update the function in the IRModule. Besides, we add all
+ * newly created grouped functions to the IRModule.
+ */
+class OperatorFusor : public ExprMutator {
+ public:
+  /*!
+   * \brief Construct a new operator fusor. Given the indexed-forward graph and the graph partition
+   * result on that graph, the constructor creates a mapping from each leaf AST object
+   * (e.g. parameters, variables, constants) to the group of the node corresponding to the object
+   * in the graph.
+   * \param mod The IRModule to be transformed
+   * \param graph The indexed-forward graph of the input IRModule
+   * \param groups The grouped result of the group partition on the input indexed-forward graph.
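+   * \note For illustration (hypothetical partition): if the nodes of bindings `lv0` and `lv1`
+   * were partitioned into one group, both of their AST objects would map to that group's root
+   * here, and the fusor would later rewrite them into a single call to the grouped function.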
+   */
+  explicit OperatorFusor(IRModule mod, const IndexedForwardGraph& graph,
+                         const std::vector<GraphPartitioner::Group*>& groups)
+      : ExprMutator(mod), mod_(std::move(mod)) {
+    for (int nid = 0; nid < static_cast<int>(graph.post_dfs_order.size()); ++nid) {
+      GraphPartitioner::Group* group_root = groups[nid]->FindRoot();
+      ICHECK(group_root != nullptr);
+      ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
+      obj2group_[graph.post_dfs_order[nid]->ref] = group_root;
+    }
+  }
+
+  /*!
+   * \brief The main transformation on the IRModule
+   * \return The new IRModule after transformation
+   */
+  IRModule Transform() {
+    for (const auto& kv : mod_->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+      // Only visit Relax functions without the attr kPrimitive.
+      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
+        auto updated_func = Downcast<Function>(VisitExpr(func));
+        builder_->UpdateFunction(gv, updated_func);
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
+  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
+    if (const auto* df_block = block.as<DataflowBlockNode>()) {
+      return VisitBindingBlock_(df_block);
+    }
+    // We skip ordinary binding blocks since they might be impure (with side effects or control
+    // flow).
+    return block;
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    group2func_.clear();
+
+    // Step 1. Collect the bindings for each grouped function.
+    CollectFuncBindings(block->bindings);
+
+    // Step 2. Collect each group's boundary (i.e. the output vars of each group).
+    CollectFuncBoundary(block->bindings);
+
+    // Step 3. Create the grouped function for each group.
+    for (auto& kv : group2func_) {
+      FunctionCreator& creator = kv.second;
+      creator.CreateFunction();
+    }
+
+    // Step 4. Start generating the new binding block.
+    //  - For groups with a single binding, we directly recurse into the binding and emit the new
+    //  one.
+    //  - For groups with multiple bindings, we emit the call to the grouped function only when
+    //  visiting the last binding of the group, because only then do we avoid breaking the
+    //  dependencies among the bindings of different groups. Therefore, we skip all but the last
+    //  binding of the group.
+    builder_->BeginDataflowBlock();
+    for (size_t i = 0; i < block->bindings.size(); ++i) {
+      const Binding& binding = block->bindings[i];
+
+      // Case 1. If the binding is the only binding in its group, recurse into it and emit the
+      // transformed binding as usual.
+      GraphPartitioner::Group* group = GetGroupFromBinding(binding);
+      if (group->num_nodes == 1) {
+        VisitBinding(binding);
+        continue;
+      }
+
+      const auto& it_creator = group2func_.find(group);
+      ICHECK(it_creator != group2func_.end());
+      const FunctionCreator& func_info = it_creator->second;
+
+      // Case 2. If the binding is not the last binding of the group, we skip it.
+      if (!func_info.bindings_.back().same_as(binding)) {
+        continue;
+      }
+
+      // Case 3. The binding is the last binding of the group.
+      const auto* var_binding = binding.as<VarBindingNode>();
+      ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 "
+                                        "is supposed to be a variable binding";
+
+      // Step a. Add the grouped function to the IRModule.
+      GlobalVar gv = builder_->AddFunction(func_info.function_, func_info.name_hint_);
+
+      // Step b. Create the call to the deduplicated function, and then emit the call.
+      //  - If this binding is an output binding, emit an output variable.
+      //  - Otherwise, emit a dataflow variable.
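+      // For illustration (hypothetical names): if {lv0, lv1} form one group, visiting lv1 (the
+      // last binding of the group) emits a single call such as
+      //   lv1 = fused_exp_add(x)
+      // while lv0 disappears from the rewritten block.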
+      Var new_var;
+      Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_));
+
+      if (var_binding->var->IsInstance<DataflowVarNode>()) {
+        new_var = builder_->Emit(call_to_emit);
+      } else {
+        new_var = builder_->EmitOutput(call_to_emit);
+      }
+
+      // Step c. Update the mapping used for the remapping of the binding variables.
+      var_remap_[var_binding->var->vid] = new_var;
+    }
+    // Step 5. Finish the binding block generation.
+    return builder_->EndBlock();
+  }
+
+  /*!
+   * \brief Collect the bindings for each grouped function and update the information of the
+   * grouped function
+   * \param bindings The bindings to be collected
+   * \note The function update is done by `AppendBinding(...)`
+   */
+  void CollectFuncBindings(const Array<Binding>& bindings) {
+    for (const Binding& binding : bindings) {
+      // If the binding is the only binding in its group, there is no need to create a new
+      // function.
+      GraphPartitioner::Group* group = GetGroupFromBinding(binding);
+      if (group->num_nodes == 1) {
+        continue;
+      }
+      // Add the binding to the grouped function it's in, and update the function information
+      // accordingly.
+      FunctionCreator& func_info = group2func_[group];
+      func_info.AppendBinding(binding);
+    }
+  }
+
+  void CollectFuncBoundary(const Array<Binding>& bindings) {
+    for (const Binding& binding : bindings) {
+      // Step 1. Get the current binding's group.
+      GraphPartitioner::Group* cur_group = GetGroupFromBinding(binding);
+
+      // Step 2. Collect all used vars in the binding value and update the boundary.
+      //  - If the var's group is the same as the binding's, the var is defined in the same group.
+      //  - If the var's group differs from the binding's, the var must be the output from another
+      //  group. Mark it as that group's output.
+      auto update_boundary = [this, &cur_group](const Expr& e) {
+        if (e->IsInstance<VarNode>()) {
+          const Var& used_var = Downcast<Var>(e);
+          GraphPartitioner::Group* producer_group = GetGroupFromVar(used_var);
+          // Only check groups defined before.
+          // Skip vars from the input or from groups with a single binding.
+          if (producer_group != cur_group &&
+              group2func_.find(producer_group) != group2func_.end()) {
+            FunctionCreator& producer_func_info = group2func_[producer_group];
+            producer_func_info.AppendOutput(used_var);
+          }
+        }
+      };
+      if (const auto* var_binding = binding.as<VarBindingNode>()) {
+        PostOrderVisit(var_binding->value, update_boundary);
+      } else {
+        const auto* match_shape = binding.as<MatchShapeNode>();
+        ICHECK_NOTNULL(match_shape);
+        PostOrderVisit(match_shape->value, update_boundary);
+      }
+    }
+  }
+
+  /*!
+   * \brief Get the group which the input binding is in
+   * \param binding The binding to be queried
+   * \return The pointer to the group which the input binding is in
+   */
+  GraphPartitioner::Group* GetGroupFromBinding(const Binding& binding) {
+    Var var{nullptr};
+    if (const auto* var_binding = binding.as<VarBindingNode>()) {
+      var = var_binding->var;
+    } else {
+      const auto* match_shape = binding.as<MatchShapeNode>();
+      ICHECK(match_shape != nullptr);
+      var = match_shape->var;
+    }
+    return GetGroupFromVar(var);
+  }
+
+  /*!
+   * \brief Get the group which the input var is in
+   * \param var The var to be queried
+   * \return The pointer to the group which the input var is in
+   */
+  GraphPartitioner::Group* GetGroupFromVar(const Var& var) {
+    const auto& it_group = obj2group_.find(var.get());
+    ICHECK(it_group != obj2group_.end());
+    GraphPartitioner::Group* group = it_group->second;
+    ICHECK(group->FindRoot() == group);
+    return group;
+  }
+
+  /*!
+ * \brief Update the pre-stored arguments according to the variable remapping of the fusor, by + * recursing into each argument + * \param args The arguments to be updated + * \return The updated arguments + */ + Array UpdateArgs(const Array& args) { + Array new_args; + new_args.reserve(args.size()); + for (const Expr& arg : args) { + new_args.push_back(VisitExpr(arg)); + } + return new_args; + } + + private: + /*! \brief The IRModule. */ + IRModule mod_; + /*! \brief Internal arena. */ + support::Arena arena_; + /*! \brief The group assignment map. */ + std::unordered_map obj2group_; + /*! \brief Internal function information map. */ + std::unordered_map group2func_; +}; + +IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { + support::Arena arena; + + // Step 1. Create the indexed-forward graph according to the input IRModule. + IndexedForwardGraph graph = GraphCreator::Create(mod, &arena); + + // Step 2. Partition the graph by applying the fusion algorithm. + std::vector groups = + GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph); + + // Step 3. Transform the IRModule by fusing the operators in accordance with the graph partition + // results. + mod = OperatorFusor(mod, graph, groups).Transform(); + + return mod; +} + +namespace transform { + +Pass FuseOps(int fuse_opt_level) { + runtime::TypedPackedFunc pass_func = // + [=](IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps)); + return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue()); + }; + return CreateModulePass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"FuseOps", // + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc new file mode 100644 index 0000000000..ad9274f874 --- /dev/null +++ b/src/relax/transform/fuse_tir.cc @@ -0,0 +1,718 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "../../relay/analysis/graph_partitioner.h" +#include "../../support/arena.h" +#include "../../tir/ir/functor_common.h" + +namespace tvm { +namespace tir { + +// TODO(Siyuan): move it to somewhere under tir folder +/*! + * \brief Substitute a given source buffer with a given target buffer in statements or expressions. 
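+ * For example (illustrative): given buffer_map = {A -> B}, a load `A[vi, vj]` becomes
+ * `B[vi, vj]`, a bare reference to A's data var becomes B's data var, and block read/write
+ * regions over A are redirected to B.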
+ */ +class BufferSubstituter : private StmtExprMutator { + public: + static Stmt Substitute(const Map& buffer_map, Stmt stmt) { + return BufferSubstituter(buffer_map)(std::move(stmt)); + } + + private: + explicit BufferSubstituter(const Map& buffer_map) { + for (const auto& kv : buffer_map) { + const Buffer& src = kv.first; + const Buffer& tgt = kv.second; + buffer_var_map_[src->data.get()] = tgt; + } + } + + PrimExpr VisitExpr_(const VarNode* _op) final { + auto it = buffer_var_map_.find(_op); + if (it != buffer_var_map_.end()) { + return it->second->data; + } else { + return GetRef(_op); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer = it->second; + return BufferLoad(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer->data.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer = it->second; + return BufferStore(n); + } else { + return std::move(store); + } + } + + PrimExpr VisitExpr_(const LoadNode* _op) final { + Load load = Downcast(StmtExprMutator::VisitExpr_(_op)); + auto it = buffer_var_map_.find(load->buffer_var.get()); + if (it != buffer_var_map_.end()) { + auto n = make_object(*load.get()); + n->buffer_var = it->second->data; + return Load(n); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const StoreNode* _op) final { + Store store = Downcast(StmtExprMutator::VisitStmt_(_op)); + auto it = buffer_var_map_.find(store->buffer_var.get()); + if (it != buffer_var_map_.end()) { + auto n = CopyOnWrite(store.get()); + n->buffer_var = it->second->data; + return Store(n); + } else { + return std::move(store); + } + } + + Stmt VisitStmt_(const BlockNode* _op) final { + Block block = Downcast(StmtMutator::VisitStmt_(_op)); + + // Define the mutation functions. + auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { + const Buffer& src_buffer = match_buffer->source->buffer; + auto it = buffer_var_map_.find(src_buffer->data.get()); + if (it != buffer_var_map_.end()) { + return MatchBufferRegion(match_buffer->buffer, + BufferRegion(it->second, match_buffer->source->region)); + } else { + return match_buffer; + } + }; + + auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { + auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); + return it == buffer_var_map_.end() ? buffer_region + : BufferRegion(it->second, buffer_region->region); + }; + + // Step 1. Mutate `match_buffers`. + Array match_buffers = + MutateArray(block->match_buffers, f_mutate_match_buffers); + // Step 2. Mutate the read/write region. + Array reads = MutateArray(block->reads, f_mutate_read_write_region); + Array writes = MutateArray(block->writes, f_mutate_read_write_region); + + reads = UnionAccessRegion(reads); + writes = UnionAccessRegion(writes); + + if (reads.same_as(block->reads) && // + writes.same_as(block->writes) && // + match_buffers.same_as(block->match_buffers)) { + return std::move(block); + } else { + auto n = CopyOnWrite(block.get()); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->match_buffers = std::move(match_buffers); + return Block(n); + } + } + + private: + /*! 
\brief Mapping from src buffer.data to tgt buffer. */
+  std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
+  /*! \brief The structural equality checker */
+  StructuralEqual structural_equal_;
+
+  Array<BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
+    // For now we only allow buffer accesses to the same elements.
+    // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but needs to be unioned to `A[vi, vj]`.
+    // However, `A[vi, vj], A[vi, vj + 1]` is not allowed for now.
+    // Note: the order of the returned regions should remain the same as the first occurrence of
+    // each region.
+    Array<BufferRegion> ret;
+    std::unordered_map<const BufferNode*, Region> buffer_region_set;
+
+    for (const BufferRegion& region : regions) {
+      auto it = buffer_region_set.find(region->buffer.get());
+      if (it == buffer_region_set.end()) {
+        ret.push_back(region);
+        buffer_region_set[region->buffer.get()] = region->region;
+      } else {
+        ICHECK(structural_equal_(region->region, it->second));
+      }
+    }
+
+    if (ret.size() == regions.size()) {
+      return regions;
+    } else {
+      return ret;
+    }
+  }
+};
+
+/*! \brief A mutator that detects block name duplication and deduplicates the names. */
+class BlockNameDeduplicator : public tir::StmtMutator {
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(tir::StmtMutator::VisitStmt_(op));
+
+    String name = GetUniqueName(block->name_hint);
+
+    if (name == block->name_hint) {
+      return std::move(block);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
+      n->name_hint = std::move(name);
+      return Stmt(n);
+    }
+  }
+
+  String GetUniqueName(const String& prefix) {
+    String unique_prefix = prefix;
+    auto it = name_count_.find(prefix);
+    while (name_count_.count(unique_prefix)) {
+      unique_prefix = prefix + "_" + std::to_string(++it->second);
+    }
+    name_count_[unique_prefix] = 0;
+    return unique_prefix;
+  }
+
+  // TODO(relax-team): It should detect the number suffix and do the renaming properly,
+  // e.g. GetUniqueName("name1") should return "name2" instead of "name10".
+  /*! \brief The count map to make block names unique. */
+  std::unordered_map<String, int> name_count_;
+};
+
+}  // namespace tir
+
+namespace relax {
+
+class FusedTIRConstructor : public ExprVisitor {
+ public:
+  /*!
+   * \brief Construct a fused TIR PrimFunc from a relax sub-function
+   * \param mod The IRModule
+   * \param gv The global var of the relax sub-function to be fused into one PrimFunc
+   * \return The fused TIR PrimFunc
+   */
+  static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
+    FusedTIRConstructor visitor(mod, gv->name_hint);
+    BaseFunc f = mod->Lookup(gv);
+    CHECK(f->IsInstance<relax::FunctionNode>())
+        << "Expected a relax function, but got: " << f->GetTypeKey();
+    CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
+        << "Expected a function with attr `kPrimitive`";
+    visitor(Downcast<relax::Function>(f));
+    return visitor.fused_tir_;
+  }
+
+ private:
+  explicit FusedTIRConstructor(const IRModule& mod, const String& func_name)
+      : mod_(mod), func_name_(func_name) {}
+
+  void VisitExpr_(const FunctionNode* func) final {
+    // Step 1. Create buffers for the function params.
+    for (const Var& relax_param : func->params) {
+      auto ret = CreateParamsAndBuffers(relax_param->checked_type(),  //
+                                        relax_param->shape(),         //
+                                        relax_param->name_hint());
+      const Array<tir::Var>& params = ret.first;
+      const Array<tir::Buffer>& buffers = ret.second;
+      ICHECK_EQ(params.size(), buffers.size());
+      for (size_t i = 0; i < params.size(); ++i) {
+        func_info_.buffer_map.Set(params[i], buffers[i]);
+        func_info_.params.push_back(params[i]);
+      }
+      func_info_.expr2buffers.Set(relax_param, buffers);
+    }
+
+    // Step 2. Visit the function body and create intermediate buffers.
+    ExprVisitor::VisitExpr_(func);
+
+    // Step 3. Create and remap buffers for the function output.
+    ICHECK(func->body->IsInstance<SeqExprNode>())
+        << "Function body is expected to be a SeqExpr, but got: " << func->body->GetTypeKey();
+    Expr body = Downcast<SeqExpr>(func->body)->body;
+    auto it = func_info_.expr2buffers.find(body);
+    ICHECK(it != func_info_.expr2buffers.end())
+        << "Failed to detect the output buffers for the function body";
+    const Array<tir::Buffer>& buffers = (*it).second;
+    for (size_t i = 0; i < buffers.size(); ++i) {
+      tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle()));
+      func_info_.buffer_map.Set(param, buffers[i]);
+      func_info_.params.push_back(param);
+      func_info_.output_buffers.insert(buffers[i].get());
+    }
+
+    // Step 4. Create the PrimFunc.
+    fused_tir_ = ConstructFunc();
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    // Update expr2buffers by visiting the value.
+    this->VisitExpr(binding->value);
+    auto it = func_info_.expr2buffers.find(binding->value);
+    if (it != func_info_.expr2buffers.end()) {
+      // Assign the binding var to the buffers of the value.
+      func_info_.expr2buffers.Set(binding->var, (*it).second);
+    } else {
+      LOG(FATAL) << "Unsupported binding value: " << binding->value;
+    }
+  }
+
+  void VisitBinding_(const MatchShapeNode* match_shape) final {
+    // TODO(relax-team): support match shape in primitive functions.
+    LOG(FATAL) << "MatchShape is unsupported in primitive functions";
+  }
+
+  void VisitExpr_(const CallNode* call) final {
+    ExprVisitor::VisitExpr_(call);
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    ICHECK(call->op == call_tir_op_)
+        << "Only call_tir is supported in a primitive function, but got: " << GetRef<Call>(call);
+
+    // Step 1. Get the GlobalVar and the PrimFunc.
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
+    ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: "
+                                 << gv;
+    // Step 2. Renew all var/buffer definitions and blocks to avoid duplication.
+    tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());
+
+    // Step 3. Check that all functions are schedulable funcs, i.e. the body of the func is a
+    // root block.
+    // TODO(Siyuan): support un-schedulable functions.
+    ICHECK(prim_func->body->IsInstance<tir::BlockRealizeNode>())
+        << "Only schedulable functions (whose body is the root block) can be fused";
+    const tir::BlockRealize& root_realize = Downcast<tir::BlockRealize>(prim_func->body);
+    const tir::Block& root_block = root_realize->block;
+
+    // Step 4. Add all the original alloc_buffers and the body to the fused function.
+    func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(),
+                                    root_block->alloc_buffers.begin(),
+                                    root_block->alloc_buffers.end());
+    func_info_.bodies.push_back(root_block->body);
+
+    // Step 5. Map the input arguments to buffers.
+    MapInputBuffer(prim_func, call->args[1]);
+    size_t num_output_buffers = GetCallTIROutputSize(call);
+    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, num_output_buffers);
+    // Update the fused func name.
+    func_info_.global_name += "_" + gv->name_hint;
+  }
+
+  void VisitExpr_(const TupleGetItemNode* tuple_get_item) final {
+    ExprVisitor::VisitExpr_(tuple_get_item);
+    auto it = func_info_.expr2buffers.find(tuple_get_item->tuple);
+    if (it != func_info_.expr2buffers.end()) {
+      int begin_buf_idx = 0;
+      int end_buf_idx = 0;
+      const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
+      for (int i = 0; i < tuple_get_item->index; ++i) {
+        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+      }
+      end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
+      func_info_.expr2buffers.Set(
+          GetRef<Expr>(tuple_get_item),
+          {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx});
+    }
+  }
+
+  void VisitExpr_(const TupleNode* tuple) final {
+    ExprVisitor::VisitExpr_(tuple);
+    Array<tir::Buffer> buffers;
+    for (const Expr& expr : tuple->fields) {
+      auto it = func_info_.expr2buffers.find(expr);
+      if (it != func_info_.expr2buffers.end()) {
+        buffers.insert(buffers.end(), (*it).second.begin(), (*it).second.end());
+      }
+    }
+    if (!buffers.empty()) {
+      func_info_.expr2buffers.Set(GetRef<Expr>(tuple), buffers);
+    }
+  }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
+  }
+
+  /********** Helper Functions **********/
+
+  /*!
+   * \brief Pattern match op to a TIR function and look it up.
+   * \return The TIR function, or NullOpt if the pattern match fails.
+   */
+  Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
+    // NOTE: the `as` check works for nullptr (returns null).
+    Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
+    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+      return GetRef<tir::PrimFunc>(pfunc);
+    } else {
+      return NullOpt;
+    }
+  }
+
+  /*!
+   * \brief Get the number of outputs for a call_tir node.
+   * \return The number of outputs.
+   */
+  static size_t GetCallTIROutputSize(const CallNode* call) {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    ICHECK(call->op.same_as(call_tir_op_));
+    const Expr& output_shapes = call->args[2];
+    if (const auto* tuple_output_shapes = output_shapes.as<TupleNode>()) {
+      return tuple_output_shapes->fields.size();
+    } else {
+      return 1;
+    }
+  }
+
+  /*! \brief Map old TIR func param buffers to new buffers, and then update `buffer_subst_map` */
+  void MapArgsToBuffer(const Array<Expr>& args, const Array<tir::Buffer>& buffers) {
+    size_t buffer_idx = 0;
+    for (const Expr& arg : args) {
+      if (const auto* v = arg.as<VarNode>()) {
+        auto it = func_info_.expr2buffers.find(GetRef<Var>(v));
+        // Substitute the buffer with the already allocated one if it is an intermediate var.
+        if (it != func_info_.expr2buffers.end()) {
+          for (const tir::Buffer& target_buffer : (*it).second) {
+            ICHECK_LT(buffer_idx, buffers.size());
+            const tir::Buffer& buffer = buffers[buffer_idx];
+            // TODO(relax-team): Add support for symbolic shape fusion.
+            for (const PrimExpr& shape_expr : buffer->shape) {
+              ICHECK(shape_expr.as<IntImmNode>()) << "Only constant shape fusion is supported for now";
+            }
+            func_info_.buffer_subst_map.Set(buffer, target_buffer);
+            buffer_idx++;
+          }
+        }
+      }
+    }
+    // Make sure every buffer is mapped.
+    ICHECK_EQ(buffer_idx, buffers.size());
+  }
+
+  /*!
+   * \brief Update the buffer mapping `func_info_.buffer_subst_map` for input args
+   * \param func The old TIR PrimFunc
+   * \param args The relax arguments to the call_tir, either a single expression or a Tuple
+   */
+  void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) {
+    Array<Expr> arg_list;
+    Array<tir::Buffer> buffer_list;
+    if (const auto* arg_tuple = args.as<TupleNode>()) {
+      arg_list = arg_tuple->fields;
+    } else {
+      arg_list = {args};
+    }
+
+    ICHECK_GE(func->params.size(), arg_list.size());
+    for (size_t i = 0; i < arg_list.size(); ++i) {
+      const tir::Var& param = func->params[i];
+      const tir::Buffer& buffer = func->buffer_map.at(param);
+      buffer_list.push_back(buffer);
+    }
+
+    MapArgsToBuffer(arg_list, buffer_list);
+  }
+
+  /*!
+   * \brief Allocate buffer(s) and update `func_info.expr2buffers` if the PrimFunc output(s) are
+   * intermediate results.
+   * \param expr The relax Expr, which can be binding vars or binding values.
+   * \param func The old TIR PrimFunc
+   * \param output_size The number of output params. All output params are at the end of the param
+   * list.
+   */
+  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, size_t output_size) {
+    size_t n = func->params.size();
+    ICHECK_GE(n, output_size);
+    // Allocate intermediate buffers.
+    Array<tir::Buffer> alloc_buffers;
+    for (size_t i = 0; i < output_size; ++i) {
+      const tir::Var& param = func->params[n - output_size + i];
+      const tir::Buffer& buffer = func->buffer_map.at(param);
+      func_info_.alloc_buffers.push_back(buffer);
+      alloc_buffers.push_back(buffer);
+    }
+    // Update expr2buffers.
+    func_info_.expr2buffers.Set(expr, alloc_buffers);
+  }
+
+  /*!
+   * \brief Create TIR func params and buffers with the specified relax type and shape
+   * \param type The specified relax type, which can be DynTensorType or Tuple
+   * \param shape The specified shape, which can be ShapeExpr or Tuple
+   * \param name_hint The name hint for params and buffers
+   * \param index The index used for a unique name_hint if the type is Tuple.
+   *  -1 means no postfix is needed since the relax param is not a Tuple.
+   * \return The created TIR func params and buffers
+   */
+  static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
+      Type type, relax::Expr shape, const String& name_hint, int index = -1) {
+    Array<tir::Var> params;
+    Array<tir::Buffer> buffers;
+    if (const auto* shape_expr = shape.as<ShapeExprNode>()) {
+      // Case 1. The relax param is a DynTensor; we directly create a tir var and buffer.
+      ICHECK(type->IsInstance<DynTensorTypeNode>());
+      String name = index == -1 ? name_hint : name_hint + "_" + std::to_string(index);
+      DataType dtype = Downcast<DynTensorType>(type)->dtype;
+      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
+      // Differentiate buffer names and param names by adding the prefix `p_` to params.
+      // Every symbol should be unique in TVMScript, and Buffer is used more than param,
+      // so we decide to make sure buffer names have better readability.
+      tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
+      params.push_back(std::move(param));
+      buffers.push_back(std::move(buffer));
+    } else if (const auto* shape_tuple = shape.as<TupleNode>()) {
+      // Case 2. The relax param is a Tuple; recursively visit each field until reaching a
+      // DynTensor.
+      ICHECK(type->IsInstance<TupleTypeNode>());
+      TupleType tuple_type = Downcast<TupleType>(type);
+      // Enable the postfix.
+      if (index == -1) index = 0;
+      for (size_t i = 0; i < shape_tuple->fields.size(); ++i) {
+        auto ret =
+            CreateParamsAndBuffers(tuple_type->fields[i], shape_tuple->fields[i], name_hint, index);
+        const Array<tir::Var>& ret_params = ret.first;
+        const Array<tir::Buffer>& ret_buffers = ret.second;
+        ICHECK_EQ(ret_params.size(), ret_buffers.size());
+        // Add the tuple field results to the end of params and buffers.
+        params.insert(params.end(), ret_params.begin(), ret_params.end());
+        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
+        index += ret_params.size();
+      }
+    } else {
+      ICHECK(false) << "shapes are expected to be ShapeExprNode or TupleNode";
+    }
+    return std::make_pair(params, buffers);
+  }
+
+  /*!
+   * \brief Construct the fused TIR func with the collected FuseFuncInfo
+   * \return The fused TIR
+   */
+  tir::PrimFunc ConstructFunc() {
+    Map<String, ObjectRef> attr_map;
+    attr_map.Set("tir.noalias", tir::const_true());
+    ICHECK(func_info_.global_name != "fused");
+    // TODO(relax-team): remove global_symbol later.
+    attr_map.Set("global_symbol", String(func_info_.global_name));
+    // Remove output buffers from func_info_.alloc_buffers.
+    Array<tir::Buffer> alloc_buffers;
+    for (const tir::Buffer& buf : func_info_.alloc_buffers) {
+      if (func_info_.output_buffers.count(buf.get()) == 0) {
+        alloc_buffers.push_back(buf);
+      }
+    }
+    tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
+    body = tir::BufferSubstituter::Substitute(func_info_.buffer_subst_map, body);
+    body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
+    body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
+    tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
+                       Optional<Map<tir::Var, tir::Buffer>>(), DictAttrs(attr_map));
+    return func;
+  }
+
+  /*! \brief Get the number of DynTensors from recursive Tuples. */
+  static size_t GetTotalTensorSize(const Type& type) {
+    if (type.as<DynTensorTypeNode>()) {
+      return 1;
+    } else if (const auto* tuple_type = type.as<TupleTypeNode>()) {
+      size_t num = 0;
+      for (const Type& type : tuple_type->fields) {
+        num += GetTotalTensorSize(type);
+      }
+      return num;
+    } else {
+      LOG(FATAL) << "DynTensorType and TupleType are expected, but got: " << type;
+      return 0;
+    }
+  }
+
+  /********** Function Info **********/
+
+  /*! \brief Auxiliary information for FuseTIR */
+  struct FuseFuncInfo {
+    /*! \brief The arguments for calling the prim_func */
+    Array<Expr> arguments;
+    /*!
+     * \brief The map from each dataflow var (intermediate var) to the corresponding buffers
+     * allocated in the fused func
+     */
+    Map<Expr, Array<tir::Buffer>> expr2buffers;
+    /*! \brief The buffers to allocate in the fused func */
+    Array<tir::Buffer> alloc_buffers;
+    /*! \brief The bodies of the original funcs, which together form the body of the fused func. */
+    Array<tir::Stmt> bodies;
+    /*! \brief The params of the fused function */
+    Array<tir::Var> params;
+    /*!
+     * \brief The map from buffers in the original functions to the corresponding buffers in the
+     * fused function
+     */
+    Map<tir::Buffer, tir::Buffer> buffer_subst_map;
+    /*! \brief The `buffer_map` of the fused function */
+    Map<tir::Var, tir::Buffer> buffer_map;
+    /*! \brief The output buffers in the function buffer_map */
+    std::unordered_set<const tir::BufferNode*> output_buffers;
+    /*! \brief The name of the fused function */
+    std::string global_name = "fused";
+  };
+
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  /*! \brief The name hint for the input func. */
+  String func_name_;
+  /*! \brief The helper info to fuse TIR prim_funcs */
+  FuseFuncInfo func_info_;
+  /*! \brief The TIR function after fusion */
+  tir::PrimFunc fused_tir_;
+};
+
+/*!
+ * \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR.
+ */
+class TIRFuseMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod) {
+    // Since TIRFuseMutator will delete a bunch of PrimFuncs, we create an empty block builder.
+    TIRFuseMutator mutator(mod);
+    // Step 1. Fuse all primitive relax functions and store the result in `fused_tir_funcs_`.
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+      // Only fuse primitive relax functions.
+      if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
+        tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
+        mutator.fused_tir_funcs_.Set(gv, fused_tir);
+      }
+    }
+
+    // Step 2. Update all non-primitive relax functions and add them, together with the functions
+    // they depend on, into the new IRModule.
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
+        relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
+        mutator.builder_->AddFunction(update_func, gv->name_hint);
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}
+
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) final {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+    if (call->op->IsInstance<GlobalVarNode>()) {
+      // Case 1. It is a relax cross-function call.
+      GlobalVar old_gv = Downcast<GlobalVar>(call->op);
+      auto it = fused_tir_funcs_.find(old_gv);
+      if (it != fused_tir_funcs_.end()) {
+        const tir::PrimFunc& fused_tir = (*it).second;
+        // Case 1.1. It calls a primitive relax function; update the call into a call_tir.
+        GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
+        // Step a. Flatten all args since call_tir does not support Tuple values.
+        Array<Expr> arg_list;
+        for (const Expr& arg : call->args) {
+          Array<Expr> flattened = FlattenArg(arg);
+          arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
+        }
+        // Step b. Create the call_tir.
+        Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list), call->shape()};
+        return Call(call_tir_op_, call_args, call->attrs, {call->checked_type()});
+      } else {
+        // Case 1.2. The callee function is not primitive; nothing to do.
+        return call;
+      }
+    } else if (call->op == call_tir_op_) {
+      // Case 2. It is a call_tir; re-emit the PrimFunc.
+      GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+      GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
+      return Call(call->op, {new_gv, call->args[1], call->args[2]}, call->attrs, call->type_args,
+                  call->span);
+    } else {
+      // Case 3. Other kinds of CallNode. Leave it as is.
+      return call;
+    }
+  }
+
+  /********** Helper Functions **********/
+
+  /*!
+   * \brief Flatten the call args if it's a Tuple by emitting `TupleGetItem`.
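+   * For illustration (hypothetical): a tuple argument `t` with two tensor fields is expanded
+   * into `lv0 = t[0]; lv1 = t[1]`, and `[lv0, lv1]` is what gets passed to the fused PrimFunc.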
+   */
+  Array<Expr> FlattenArg(const Expr& arg) {
+    if (const auto* tuple_shape = arg->shape().as<TupleNode>()) {
+      Array<Expr> arg_list;
+      for (size_t i = 0; i < tuple_shape->fields.size(); ++i) {
+        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
+        Array<Expr> flattened = FlattenArg(new_arg);
+        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
+      }
+      return arg_list;
+    } else {
+      return {arg};
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  /*! \brief The map from the global var of a primitive relax function to the generated prim func. */
+  Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
+};
+
+IRModule FuseTIR(IRModule mod) {
+  mod = TIRFuseMutator::Transform(mod);
+  return mod;
+}
+
+namespace transform {
+
+Pass FuseTIR() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
+  return CreateModulePass(/*pass_function=*/pass_func,  //
+                          /*opt_level=*/0,              //
+                          /*pass_name=*/"FuseTIR",      //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc
new file mode 100644
index 0000000000..9a0ae5f098
--- /dev/null
+++ b/src/relax/transform/lambda_lift.cc
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/lambda_lift.cc
+ * \brief Lift local functions into global functions.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+/* The goal of this class is to lift out any nested functions into top-level
+ * functions.
+ *
+ * We will lift a function out into a global which takes the set of free
+ * vars and then returns the newly created function.
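+ *
+ * For example (illustrative, with made-up names): a local
+ *   inner = fn(x) { add(x, y) }   // captures y
+ * is lifted to a global "lifted_func_0" whose params are (x, y); the original binding site
+ * becomes relax.make_closure(lifted_func_0, (y,)), and calls through `inner` are rewritten to
+ * relax.invoke_closure.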
+ */ +class LambdaLifter : public ExprMutator { + public: + explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* var = call_node->op.as()) { + bool has_closure = HasClosure(GetRef(var)); + auto val = builder_->LookupBinding(GetRef(var)); + // Call "relax.invoke_closure" to invoke closure + if (has_closure && val.as()) { + Var clo_arg = GetRef(var); + if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { + clo_arg = this->var_remap_.at(var->vid); + } + return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, + {call_node->checked_type_}); + } + } + if (auto global_var_node = call_node->op.as()) { + String rec_name = global_var_node->name_hint; + auto global_var = GetRef(global_var_node); + auto it = lambda_map_.find(global_var); + if (it != lambda_map_.end()) { + // flatten nested call, e.g. call(y)(x) -> call(x, y)) + Array new_args; + for (const auto arg : call->args) { + new_args.push_back(arg); + } + if (const auto* nest_call = it->second.as()) { + for (const auto arg : nest_call->args) { + new_args.push_back(arg); + } + return Call(nest_call->op, new_args, call_node->attrs, call_node->type_args); + } + return Call(it->second, call->args, call_node->attrs, call_node->type_args); + } + } + return std::move(call); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + + // TODO(@yongwww): consider appending inner func name into the lifted func name + String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); + auto global = GlobalVar(lift_func_name); + Array captured_vars = FreeVars(func); + recur_vars_ = CalledGlobalVars(func); + auto all_global_vars = AllGlobalVars(func); + + Array typed_captured_vars; + Map rebinding_map; + for (auto free_var : captured_vars) { + Var var = Var(free_var->name_hint(), NullOpt, free_var->checked_type_, free_var->span); + var->shape_ = free_var->shape_; + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); + } + + // recursive call + if (!recur_vars_.empty()) { + if (!captured_vars.empty()) { + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); + } else { + if (recur_vars_.size() > 0) { + lambda_map_.emplace(recur_vars_.back(), global); + } + } + } + + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : func_node->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(func_node->body); + Expr visited_func; + + if (all_params_unchanged && body.same_as(func_node->body)) { + visited_func = GetRef(func_node); + } else if (body->checked_type_.as()) { + // make_closure was introduced + // TODO(@sslyu): Determine if we can fill in the return shape + visited_func = + Function(params, body, body->checked_type_, RuntimeDepShape(), func_node->attrs); + } else { + visited_func = + Function(params, body, func_node->ret_type, RuntimeDepShape(), func_node->attrs); + } + auto new_func = Downcast(visited_func); + + Function lifted_func; + bool is_closure = IsClosure(captured_vars); + if (!is_closure) { + lifted_func = Function( + /*params=*/new_func->params, + /*body=*/new_func->body, + /*ret_type=*/new_func->ret_type, + 
/*ret_shape=*/new_func->ret_shape, + /*attrs=*/new_func->attrs, + /*span=*/new_func->span); + } else { + // Flatten the Closure + std::vector closure_params; + closure_params.reserve(func->params.size() + typed_captured_vars.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + closure_params.emplace_back(func->params[i]); + } + for (size_t i = 0; i < typed_captured_vars.size(); ++i) { + closure_params.emplace_back(typed_captured_vars[i]); + } + + lifted_func = Function(/*params=*/closure_params, + /*body=*/Bind(new_func->body, rebinding_map), + /*ret_type=*/new_func->ret_type, + /*ret_shape=*/new_func->ret_shape, + /*attrs=*/new_func->attrs, + /*span=*/func->span); + + Array param_types; + for (Var param : closure_params) { + CHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_"; + param_types.push_back(param->checked_type_); + } + } + lifted_func = WithAttr(std::move(lifted_func), tvm::attr::kGlobalSymbol, lift_func_name); + + ICHECK(lifted_func.defined()); + + // Add the lifted function to the module. + builder_->UpdateFunction(global, lifted_func); + + if (!is_closure) { + return std::move(global); + } else { + // If we need to allocate a closure, + // we pass the variables in its environment here. + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + // Call make_closure intrinsic + return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {}); + } + } + + bool HasClosure(const Var& var) { + auto val = builder_->LookupBinding(var); + if (const auto* value = val.as()) { + IRModule ctx_mod = builder_->GetContextIRModule(); + ICHECK(ctx_mod->functions.size() > 0); + BaseFunc func = ctx_mod->Lookup(GetRef(value)); + if (const auto* func_node = func.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } + } else if (const auto* func_node = val.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } else if (const auto* call_node = val.as()) { + // recursive call + auto op = call_node->op; + if (make_closure_op_ == op) { + return true; + } + if (const auto* lv = op.as()) { + return HasClosure(GetRef(lv)); + } + } + return false; + } + + bool IsClosure(const Array& captured_vars) { return captured_vars.size() > 0; } + + IRModule Lift() { + auto glob_funcs = mod_->functions; + for (auto pair : glob_funcs) { + if (auto* n = pair.second.as()) { + auto func = GetRef(n); + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->ret_shape, + func->attrs); + builder_->UpdateFunction(pair.first, func); + } + } + return builder_->GetContextIRModule(); + } + + private: + std::unordered_map lambda_map_; + Array recur_vars_; + IRModule mod_; + size_t lift_func_num_ = 0; + /*! \brief Cache ops that would be used later to reduce lookup overhead. 
*/ + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +namespace transform { + +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc new file mode 100644 index 0000000000..7af7b678be --- /dev/null +++ b/src/relax/transform/meta_schedule.cc @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/meta_schedule.cc + * \brief Pass for meta_schedule tuning + */ +#include +#include +#include +#include + +#include "../../printer/text_printer.h" + +namespace tvm { +namespace relax { +namespace transform { + +class MetaScheduleTuner { + public: + explicit MetaScheduleTuner(Target target, String work_dir, Integer max_trials_global, + Map params = {}) + : target_(target), + work_dir_(work_dir), + max_trials_global_(max_trials_global), + params_(params) { + candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate"); + ICHECK(candgen_func_) << "Default candidate generation function is not found."; + normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { + Trace trace = Downcast(ctx->GetCurrentTrace()); + ctx->PopTrace(); + Choice choice("tvm.meta_schedule.tune_relax", {params_, target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + Trace best_trace = candidates[0]; + ctx->PushTrace(best_trace); + // since we separate tuning from application, return original IRModule + return mod; + } + + // TODO(@sunggg): Currently, only supports basic arguments. + tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { + // TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace + // stack. Revisit later when we collect more usecases. 
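+    // Illustrative expansion (hypothetical): the single knob below yields exactly one candidate
+    // trace whose choice invokes "tvm.meta_schedule.tune_tir" with (target, work_dir,
+    // max_trials_global); the tuning result lives in the trace, and the original f is returned.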
+ Trace trace = Trace((*normalize_mod_func_)(f), {}, {}); + + Choice choice("tvm.meta_schedule.tune_tir", {target_, work_dir_, max_trials_global_}, + "relax.tuning_api.Choice.default_constr_func", {}); + Knob knob("meta_schedule.tune_primfunc", {{"0", choice}}); + Array candidates = (*candgen_func_)(Array({knob}), trace); + ICHECK(candidates.size() == 1); + // since we separate tuning from application, return original IRModule + return f; + } + + private: + Target target_; + String work_dir_; + Integer max_trials_global_; + Map params_; + const runtime::PackedFunc* candgen_func_; + const runtime::PackedFunc* normalize_mod_func_; +}; + +Pass MetaScheduleApplyDatabase(Optional work_dir) { + using tvm::meta_schedule::Database; + Target target = Target::Current(false); + const runtime::PackedFunc* normalize_mod_func_ = + runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); + ICHECK(normalize_mod_func_) << "Normalization function is not found."; + + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + Database database; + if (Database::Current().defined()) { + database = Database::Current().value(); + } else { + ICHECK(work_dir.defined()); + String path_workload = work_dir.value() + "/database_workload.json"; + String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload + << ", Tuning records at: " << path_tuning_record; + database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); + } + + Map result; + for (const auto& iter : mod->functions) { + GlobalVar gv = iter.first; + BaseFunc base_func = iter.second; + if (const auto* prim_func_node = base_func.as()) { + tir::PrimFunc prim_func = GetRef(prim_func_node); + // Global symbol has to be defined. 
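+        // For illustration (hypothetical workload): a PrimFunc whose global symbol is
+        // "fused_exp_add" is normalized into a fresh IRModule and looked up below via
+        //   database->QuerySchedule(tir_mod, target, "fused_exp_add");
+        // on a hit, the tuned body replaces the original one.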
+ Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(gsymbol.defined()); + + IRModule tir_mod = (*normalize_mod_func_)(prim_func); + if (Optional sch = database->QuerySchedule(tir_mod, target, gv->name_hint)) { + IRModule new_mod = sch.value()->mod(); + ICHECK_EQ(new_mod->functions.size(), 1); + BaseFunc new_base_func = (*new_mod->functions.begin()).second; + ICHECK(new_base_func->IsInstance()); + tir::PrimFunc new_prim_func = Downcast(new_base_func); + // copy the original attrs + new_prim_func = WithAttrs(std::move(new_prim_func), {prim_func->attrs->dict}); + result.Set(gv, new_prim_func); + continue; + } else { + LOG(WARNING) << "Tuning record is not found for primfunc: " << gsymbol.value(); + } + } + result.Set(gv, base_func); + } + return IRModule(result, // functions + {}, // type_definitions + {}, // import_set + {}, // map + mod->attrs); // attrs); + }; + return CreateModulePass(pass_func, 0, "MetaScheduleApplyDatabase", {}); +} + +Pass MetaScheduleTuneIRMod(Map params, String work_dir, + Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global, params).TuneIRMod(m, ctx); + }; + return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneIRModule", + /*required*/ {}, + /*traceable*/ true); +} + +Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { + Target target = Target::Current(false); + runtime::TypedPackedFunc pass_func = + [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { + return MetaScheduleTuner(target, work_dir, max_trials_global).TuneTIR(f, ctx); + }; + return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneTIR", + /*required*/ {}, + /*traceable*/ true); +} + +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") + .set_body_typed(MetaScheduleApplyDatabase); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc new file mode 100644 index 0000000000..8beb2b6b5a --- /dev/null +++ b/src/relax/transform/normalize.cc @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/transform/normalize.cc + * \brief Pass for transforming Relax IR to normal form, i.e., the expressions are normalized(no + * nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are + * available. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +// TODO(@altanh): LCA binding lifting +class NormalizeMutator : public ExprMutatorBase { + public: + NormalizeMutator() { builder_ = BlockBuilder::Create(NullOpt); } + + Expr VisitExpr(const Expr& expr) override { + return builder_->Normalize(ExprMutatorBase::VisitExpr(expr)); + } + + Expr VisitExpr_(const FunctionNode* op) { + Expr body = this->VisitWithNewScope(op->body); + + if (body.same_as(op->body)) { + return GetRef(op); + } else { + return Function(op->params, body, op->ret_type, op->ret_shape, op->attrs); + } + } + + Expr VisitExpr_(const IfNode* op) { + Expr guard = this->VisitExpr(op->cond); + Expr true_b = this->VisitWithNewScope(op->true_branch); + Expr false_b = this->VisitWithNewScope(op->false_branch); + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } else { + return If(guard, true_b, false_b, op->span); + } + } + + Expr VisitWithNewScope(const Expr& expr) { + builder_->BeginBindingBlock(); + Expr ret = this->VisitExpr(expr); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + ret = SeqExpr({prologue}, ret); + } + return ret; + } + + Expr VisitExpr_(const SeqExprNode* op) final { + bool all_blocks_unchanged = true; + Array blocks; + for (auto block : op->blocks) { + BindingBlock new_block = this->VisitBindingBlock(block); + if (!new_block->bindings.empty()) { + blocks.push_back(new_block); + } + all_blocks_unchanged &= block.same_as(new_block); + } + + builder_->BeginBindingBlock(); + Expr body = this->VisitExpr(op->body); + BindingBlock prologue = builder_->EndBlock(); + if (!prologue->bindings.empty()) { + blocks.push_back(prologue); + all_blocks_unchanged = false; + } + + if (all_blocks_unchanged && body.same_as(op->body)) { + return GetRef(op); + } else { + return SeqExpr(blocks, body); + } + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + BindingBlock ret; + if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else if (const auto* node = block.as()) { + ret = VisitBindingBlock_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + } + return ret; + } + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) { + builder_->BeginDataflowBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + void VisitBinding(const Binding& binding) { + if (const auto* node = binding.as()) { + VisitBinding_(node); + } else if (const auto* node = binding.as()) { + VisitBinding_(node); + } else { + LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + } + } + + void VisitBinding_(const VarBindingNode* binding) { + auto emit = [this](VarBinding b) { + if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as()) { + this->builder_->EmitOutput(b); + } else { + this->builder_->Emit(b); + } + }; + + Expr new_value = this->VisitExpr(binding->value); + if 
(!binding->var->checked_type_.defined()) { + UpdateType(binding->var, new_value->checked_type_); + } + if (!binding->var->shape_.defined()) { + UpdateShape(binding->var, new_value->shape_); + } + if (new_value.same_as(binding->value)) { + emit(GetRef(binding)); + } else { + emit(VarBinding(binding->var, new_value)); + } + } + + void VisitBinding_(const MatchShapeNode* binding) { + Expr new_value = this->VisitExpr(binding->value); + + if (binding->var.defined()) { + if (!binding->var->checked_type_.defined()) { + UpdateType(binding->var, new_value->checked_type_); + } + if (!binding->var->shape_.defined()) { + UpdateShape(binding->var, new_value->shape_); + } + } + if (new_value.same_as(binding->value)) { + builder_->EmitMatchShape(GetRef(binding)); + } else { + builder_->EmitMatchShape(MatchShape(new_value, binding->pattern, binding->var)); + } + } + + private: + /*! \brief Internal block builder to emit bindings during rewriting. */ + BlockBuilder builder_; +}; // namespace relax + +Expr Normalize(const Expr& e) { return NormalizeMutator().VisitExpr(e); } + +namespace transform { + +Pass Normalize() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(Normalize(f)); }; + return CreateFunctionPass(pass_func, 1, "Normalize", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/removed_unused_funcs.cc b/src/relax/transform/removed_unused_funcs.cc new file mode 100644 index 0000000000..e573a2ac55 --- /dev/null +++ b/src/relax/transform/removed_unused_funcs.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/remove_unused_funcs.cc + * \brief Remove unused global relax functions in a IRModule. + */ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { + +/** + * \brief Detects all the functions that can be possibly called by entry function. + */ +class CallTracer : ExprVisitor { + public: + explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{}, visiting_{} {} + + void VisitExpr_(const GlobalVarNode* op) final { + called_funcs_.insert(GetRef(op)); + auto func = mod_->Lookup(op->name_hint); + if (const auto* function_node = func.as()) { + VisitExpr(GetRef(function_node)); + } + // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. 
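The CallTracer above amounts to the mark phase of a mark-and-sweep over the module's call graph: everything transitively reachable from the entry functions is marked live, and RemoveUnusedFunctions later sweeps the rest. A standalone sketch of that idea in plain C++ (illustrative types and names, not part of this patch):

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using CallGraph = std::unordered_map<std::string, std::vector<std::string>>;

// Mark every function transitively reachable from the entry points;
// anything unmarked (and internally linked) is safe to remove.
std::unordered_set<std::string> MarkReachable(const CallGraph& graph,
                                              const std::vector<std::string>& entries) {
  std::unordered_set<std::string> live;
  std::vector<std::string> worklist(entries.begin(), entries.end());
  while (!worklist.empty()) {
    std::string fn = worklist.back();
    worklist.pop_back();
    if (!live.insert(fn).second) continue;  // already marked
    auto it = graph.find(fn);
    if (it != graph.end()) {
      worklist.insert(worklist.end(), it->second.begin(), it->second.end());
    }
  }
  return live;
}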
+ } + + void VisitExpr_(const CallNode* call_node) final { ExprVisitor::VisitExpr_(call_node); } + + void VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + if (visiting_.find(func) == visiting_.end()) { + visiting_.insert(func); + for (auto param : func_node->params) { + ExprVisitor::VisitExpr(param); + } + ExprVisitor::VisitExpr(func_node->body); + } + } + + void Trace(std::string entry) { + called_funcs_.insert(mod_->GetGlobalVar(entry)); + auto main_func = mod_->Lookup(entry); + VisitExpr(main_func); + } + + bool check_if_called(GlobalVar gv) { return called_funcs_.count(gv) > 0; } + + private: + IRModule mod_; + + // Record the names of all encountered functions. + std::unordered_set called_funcs_; + + // Record the expressions that are being visited. + std::unordered_set visiting_; +}; + +/*! + * \brief Remove functions that are not used. + * + * \param mod_ IRModule. + * \param entry_funcs The set of functions that can be entry function. + * + * \return The module with dead functions removed. + */ +IRModule RemoveUnusedFunctions(IRModule mod_, Array entry_funcs) { + auto tracer = CallTracer(mod_); + for (auto entry : entry_funcs) { + tracer.Trace(entry); + } + auto existing_functions = mod_->functions; + for (auto f : existing_functions) { + // If a function has an external linkage type, we do not remove it. + // Otherwise, we check the function and remove it if it is not used anywhere. + if (f.second->GetLinkageType() == LinkageType::kInternal && !tracer.check_if_called(f.first)) { + mod_->Remove(f.first); + } + } + return mod_; +} + +} // namespace relax + +namespace transform { +Pass RemoveUnusedFunctions(Array entry_functions) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::RemoveUnusedFunctions(m, entry_functions); }; + return CreateModulePass(pass_func, 0, "RemoveUnusedFunctions", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/transform/resolve_globals.cc b/src/relax/transform/resolve_globals.cc new file mode 100644 index 0000000000..202b9ac3d7 --- /dev/null +++ b/src/relax/transform/resolve_globals.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/resolve_globals.cc + * \brief Resolve GlobalVars using string equality. 
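Once registered, the pass runs like any other module pass. A minimal usage sketch, assuming the pass declaration is exposed via tvm/relax/transform.h and that the module defines a "main" function:

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/relax/transform.h>  // assumed location of the relax pass declarations

tvm::IRModule PruneDeadFunctions(tvm::IRModule mod) {
  // Keep only what is reachable from "main"; internally linked functions
  // that are never called are dropped by the pass defined above.
  tvm::transform::Pass dce = tvm::relax::transform::RemoveUnusedFunctions({"main"});
  return dce(mod);
}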
+ */ +#include +#include + +namespace tvm { +namespace relax { + +class GlobalVarResolver : public ExprMutator { + public: + GlobalVarResolver(IRModule mod, DiagnosticContext diag_ctx) : mod_(mod), diag_ctx_(diag_ctx) {} + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const GlobalVarNode* gvar) override { + if (!mod_->ContainGlobalVar(gvar->name_hint)) { + return GetRef(gvar); + } + return mod_->GetGlobalVar(gvar->name_hint); + } + + private: + /*! \brief the IRModule used for GlobalVar lookup. */ + IRModule mod_; + DiagnosticContext diag_ctx_; +}; + +namespace transform { + +Pass ResolveGlobals() { + runtime::TypedPackedFunc pass_func = + [](Function f, IRModule m, PassContext pc) { + // TODO(@altanh): make sure pc always has diag_ctx? + GlobalVarResolver resolver(m, pc->diag_ctx.value()); + return Downcast(resolver.VisitExpr(f)); + }; + return CreateFunctionPass(pass_func, 0, "ResolveGlobals", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ResolveGlobals").set_body_typed(ResolveGlobals); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc new file mode 100644 index 0000000000..5d954cf055 --- /dev/null +++ b/src/relax/transform/run_codegen.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/run_codegen.cc + * \brief Run codegen for annotated relax functions. + */ + +#include + +#include + +namespace tvm { +namespace relax { + +class CodeGenRunner : ExprMutator { + public: + explicit CodeGenRunner(IRModule mod, Optional> target_codegens, + Array entry_functions) + : ExprMutator(mod), entry_functions_(std::move(entry_functions)) { + if (target_codegens.defined()) { + for (auto target : target_codegens.value()) { + target_codegens_.insert(target); + } + } + } + + IRModule Run() { + IRModule mod = builder_->GetContextIRModule(); + for (const String& entry_func_name : entry_functions_) { + auto entry_func = mod->Lookup(entry_func_name); + auto gvar = mod->GetGlobalVar(entry_func_name); + builder_->UpdateFunction(gvar, Downcast(VisitExpr(entry_func))); + } + + IRModule out_mod = builder_->GetContextIRModule(); + if (ext_mods_.size()) { + out_mod = WithAttr(out_mod, "external_mods", std::move(ext_mods_)); + } + + return out_mod; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* gvarnode = call_node->op.as()) { + const GlobalVar gvar = GetRef(gvarnode); + // TODO(@sunggg): Is there any better way to get this func? 
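The GlobalVarResolver above keys purely on name_hint: any GlobalVar whose name the module defines is swapped for the module's canonical var, while unknown names pass through untouched. The same contract as a standalone sketch (illustrative types only):

#include <string>
#include <unordered_map>

struct GVar {
  std::string name;
  int id;  // stands in for the pointer identity of a GlobalVarNode
};

// Return the module's canonical var for this name, or the input unchanged
// when the module does not define it (mirroring ContainGlobalVar above).
GVar Resolve(const std::unordered_map<std::string, GVar>& module_table, const GVar& gv) {
  auto it = module_table.find(gv.name);
  return it == module_table.end() ? gv : it->second;
}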
+ Function func = Downcast(builder_->GetContextIRModule()->Lookup(gvar)); + Expr new_op = VisitExpr(func); + if (new_op->IsInstance()) { + Array new_args({new_op}); + Array tmp_args; + for (const auto& arg : call_node->args) { + tmp_args.push_back(VisitExpr(arg)); + } + new_args.push_back(Tuple(tmp_args)); + new_args.push_back(func->body->shape()); + + // TODO(@tvm-team): This is not used only for tir function anymore. Can we rename it? + static const Op& call_op = Op::Get("relax.call_tir"); + + // Remove global symbol and codegen from the function so that it can be removed. + static const runtime::PackedFunc* RemoveFuncAttrFunc = + runtime::Registry::Get("ir.BaseFuncWithoutAttr"); + ICHECK(RemoveFuncAttrFunc); + func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol); + func = (*RemoveFuncAttrFunc)(func, attr::kCodegen); + builder_->UpdateFunction(gvar, func); + + return Call(call_op, new_args, tvm::Attrs(), {func->ret_type}); + } + } + Array new_args; + for (const auto& arg : call_node->args) { + new_args.push_back(VisitExpr(arg)); + } + + return Call(call_node->op, new_args, call_node->attrs, call_node->type_args, call_node->span); + } + + Expr VisitExpr_(const FunctionNode* func_node) override { + Function func = GetRef(func_node); + auto opt_codegen = func->GetAttr(attr::kCodegen); + if (opt_codegen.defined()) { + auto opt_gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(opt_gsymbol.defined()) + << "When a codegen is defined, global symbol should be defined together."; + + String codegen_str = opt_codegen.value(); + // If the current codegen is not in the provided target lists, defer the codegen process. + if (target_codegens_.size() && target_codegens_.count(codegen_str) == 0) { + return GetRef(func_node); + } + + // Start the codegen process. + // Get the codegen with its ffi key. + String codegen_name = "relax.ext." + codegen_str; + auto codegen = runtime::Registry::Get(codegen_name); + ICHECK(codegen) << "Codegen is not found: " << codegen_name << "\n"; + // Store the produced output runtime module in the internal array. + ext_mods_.push_back((*codegen)(func)); + + // Return the external function with given global symbol. + return ExternFunc(opt_gsymbol.value()); + } else { + return ExprMutator::VisitExpr_(func_node); + } + } + + private: + Array entry_functions_; + std::unordered_set target_codegens_; + Array ext_mods_; +}; + +} // namespace relax + +namespace transform { +Pass RunCodegen(Optional> target_codegens, + Array entry_functions) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { + return relax::CodeGenRunner(m, target_codegens, entry_functions).Run(); + }; + return CreateModulePass(pass_func, 0, "RunCodegen", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc new file mode 100644 index 0000000000..c7c2627805 --- /dev/null +++ b/src/relax/transform/to_non_dataflow.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
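A hedged usage sketch for the RunCodegen pass registered above; passing NullOpt for target_codegens means every annotated codegen runs, and the header paths and wrapper function are assumptions:

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/relax/transform.h>  // assumed location of the relax pass declarations

tvm::IRModule OffloadAnnotatedFunctions(tvm::IRModule mod) {
  // Run BYOC codegen for every function carrying a kCodegen attribute,
  // starting the traversal from "main".
  tvm::transform::Pass p = tvm::relax::transform::RunCodegen(
      /*target_codegens=*/tvm::NullOpt, /*entry_functions=*/{"main"});
  return p(mod);
}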
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/to_non_dataflow.cc
+ * \brief Transform all dataflow structure to non-dataflow version.
+ */
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+class ToNonDFMutator : public ExprMutator {
+ public:
+  Var VisitVarDef(const Var& var) final {
+    if (var.as<DataflowVarNode>()) {
+      Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span);
+      UpdateShape(new_var, var->shape_);
+      this->var_remap_[var->vid] = new_var;
+      return new_var;
+    }
+    return var;
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    builder_->BeginBindingBlock();
+    for (Binding binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+    return builder_->EndBlock();
+  }
+};
+
+Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass ToNonDataflow() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(ToNonDataflow(f)); };
+  return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc
new file mode 100644
index 0000000000..177f890d56
--- /dev/null
+++ b/src/relax/transform/tuning_api/database.cc
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/tuning_api/database.cc
+ * \brief Database of tuning APIs.
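A minimal usage sketch for the ToNonDataflow pass above; once it runs, no DataflowBlock or DataflowVar remains in the module (header paths assumed):

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/relax/transform.h>  // assumed location of the relax pass declarations

tvm::IRModule StripDataflowBlocks(tvm::IRModule mod) {
  // Pass objects are callable on IRModules; this returns the rewritten module.
  return tvm::relax::transform::ToNonDataflow()(mod);
}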
+ */ +#include + +#include +#include +#include + +#include "../../../meta_schedule/utils.h" + +namespace tvm { +namespace meta_schedule { + +void JSONFileAppendLine(const String& path, const std::string& line); +std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); + +} // namespace meta_schedule +} // namespace tvm + +namespace tvm { +namespace relax { + +TuningRecord::TuningRecord(Trace trace, Optional> run_secs) { + ObjectPtr n = make_object(); + n->trace = trace; + n->run_secs = run_secs; + this->data_ = n; +} + +ObjectRef TuningRecordNode::AsJSON(bool include_irmod) const { + return Array{trace->AsJSON(include_irmod), // + run_secs}; +} + +TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { + Trace trace{nullptr}; + Optional> run_secs{nullptr}; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 2); + // Load json[0] => trace + { + const ObjectRef& json_trace = json_array->at(0); + trace = Trace::FromJSON(json_trace); + } + + // Load json[1] => run_secs + if (json_array->at(1).defined()) { + run_secs = meta_schedule::AsFloatArray(json_array->at(1)); + } + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TuningRecord(trace, run_secs); +} + +/*! \brief The struct defining comparison function of sorting by mean run seconds. */ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); + return a_time < b_time; + } +}; + +// TODO(tvm-team): Currently, we strictly treat each target separately. +// Since not every option in the target matters, this might be the overkill. +// Revisit this when we have better approach with target equality check. +inline std::string get_database_key(int workload_idx, Target target) { + return std::to_string(workload_idx) + "/" + target->str(); +} + +/*! \brief The default database implementation, which mimics two database tables with two files. + */ +class JSONDatabaseNode : public DatabaseNode { + public: + /*! \brief The path to the workload table */ + String path_workload; + /*! \brief The path to the tuning record table */ + String path_tuning_record; + /*! \brief The path to the measurement table */ + String path_measurement_record; + /*! \brief All the workloads in the database */ + std::unordered_map + workloads2idx_; + /*! \brief All the tuning records in the database */ + std::unordered_map> + tuning_records_; + + /*! 
\brief Measurement logs in the database */ + std::unordered_map> measurement_records_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("path_workload", &path_workload); + v->Visit("path_tuning_record", &path_tuning_record); + v->Visit("path_measurement_record", &path_measurement_record); + // `workloads2idx_` is not visited + // `tuning_records_` is not visited + // `measurement_records_` is not visited + } + + static constexpr const char* _type_key = "relax.tuning_api.JSONDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + + public: + bool HasWorkload(const IRModule& mod) { + return workloads2idx_.find(meta_schedule::Workload(mod, tvm::StructuralHash()(mod))) != + workloads2idx_.end(); + } + + bool HasMeasurementRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return measurement_records_.count(key) > 0; + } + + bool HasTuningRecord(const meta_schedule::Workload& workload, const Target& target) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + return tuning_records_.count(key) > 0; + } + + meta_schedule::Workload CommitWorkload(const IRModule& mod) { + // Try to insert `mod` into `workloads_` + decltype(this->workloads2idx_)::iterator it; + bool inserted = false; + std::tie(it, inserted) = + this->workloads2idx_.emplace(meta_schedule::Workload(mod, tvm::StructuralHash()(mod)), -1); + meta_schedule::Workload workload = it->first; + // If `mod` is new in `workloads2idx_`, append it to the workload file + if (inserted) { + it->second = static_cast(this->workloads2idx_.size()) - 1; + meta_schedule::JSONFileAppendLine(this->path_workload, + meta_schedule::JSONDumps(workload->AsJSON())); + } + return it->first; + } + + void CommitMeasurementRecord(const meta_schedule::Workload& workload, const Target& target, + const Array& run_secs) { + int workload_idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(workload_idx, target); + + if (measurement_records_[key].size() == 0) { + measurement_records_[key] = run_secs; + meta_schedule::JSONFileAppendLine(this->path_measurement_record, + meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), + run_secs // + })); + } else { + LOG(WARNING) << "Measurement record for " << key + << " already exists. Use the existing one instead."; + } + } + + void CommitTuningRecord(const meta_schedule::Workload& workload, const Target& target, + const TuningRecord& record) { + int workload_idx = this->workloads2idx_.at(workload); + // There may exist multiple tuning records (with different traces) for a single key pair. 
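The get_database_key helper above is the entire bucketing scheme: one string key per workload-index/target pair. Restated standalone, with an example value:

#include <string>

// Same shape as get_database_key above: "<workload index>/<target string>".
std::string DatabaseKey(int workload_idx, const std::string& target_str) {
  return std::to_string(workload_idx) + "/" + target_str;
}
// DatabaseKey(3, "llvm -num-cores 8") == "3/llvm -num-cores 8"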
+ std::string key = get_database_key(workload_idx, target); + this->tuning_records_[key].insert(record); + + meta_schedule::JSONFileAppendLine( + this->path_tuning_record, meta_schedule::JSONDumps(Array{ + Integer(workload_idx), target->Export(), record->AsJSON()})); + } + + Array GetTopK(const meta_schedule::Workload& workload, const Target& target, + int top_k) { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } + Array results; + results.reserve(top_k); + int counter = 0; + int idx = this->workloads2idx_.at(workload); + std::string key = get_database_key(idx, target); + for (const TuningRecord& record : this->tuning_records_[key]) { + results.push_back(record); + if (++counter == top_k) { + break; + } + } + + return results; + } + + Array GetMeasurementRecord(const meta_schedule::Workload& workload, + const Target target) { + int workload_idx = this->workloads2idx_.at(workload); + return this->measurement_records_[get_database_key(workload_idx, target)]; + } +}; + +Database Database::JSONDatabase(String path_workload, String path_tuning_record, + String path_measurement_record, bool allow_missing) { + int num_threads = std::thread::hardware_concurrency(); + ObjectPtr n = make_object(); + // Load `n->workloads2idx_` from `path_workload` + std::vector workloads; + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_workload, num_threads, allow_missing); + int n_objs = json_objs.size(); + n->workloads2idx_.reserve(n_objs); + workloads.reserve(n_objs); + for (int i = 0; i < n_objs; ++i) { + meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i]); + n->workloads2idx_.emplace(workload, i); + workloads.push_back(workload); + } + } + // Load `n->tuning_records_` from `path_tuning_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_tuning_record, num_threads, allow_missing); + + std::vector workload_idxs; + std::vector targets; + std::vector records; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + records.resize(size, TuningRecord{nullptr}); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + records[task_id] = TuningRecord::FromJSON(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + + for (int i = 0; i < size; i++) { + std::string key = get_database_key(workload_idxs[i], targets[i]); + n->tuning_records_[key].insert(records[i]); + } + } + + // Load `n->measuremet_log` from `path_measurement_record` + { + std::vector json_objs = + meta_schedule::JSONFileReadLines(path_measurement_record, num_threads, allow_missing); + std::vector workload_idxs; + std::vector targets; + std::vector> measurements; + int size = json_objs.size(); + workload_idxs.resize(size, -1); + targets.resize(size, Target{nullptr}); + measurements.resize(size, Array({})); + support::parallel_for_dynamic( + 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { + const ObjectRef& json_obj = json_objs[task_id]; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 3); + workload_idxs[task_id] = 
Downcast(arr->at(0)).IntValue(); + targets[task_id] = Target(Downcast>(arr->at(1))); + measurements[task_id] = meta_schedule::AsFloatArray(arr->at(2)); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + }); + for (int i = 0; i < size; i++) { + n->measurement_records_[get_database_key(workload_idxs[i], targets[i])] = measurements[i]; + } + } + + n->path_workload = path_workload; + n->path_tuning_record = path_tuning_record; + n->path_measurement_record = path_measurement_record; + return Database(n); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(TuningRecordNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") + .set_body_typed([](Trace trace, Optional> run_secs) { + return TuningRecord(trace, run_secs); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); + +TVM_REGISTER_OBJECT_TYPE(DatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") + .set_body_method(&DatabaseNode::HasWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") + .set_body_method(&DatabaseNode::HasMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") + .set_body_method(&DatabaseNode::HasTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") + .set_body_method(&DatabaseNode::CommitMeasurementRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") + .set_body_method(&DatabaseNode::CommitWorkload); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") + .set_body_method(&DatabaseNode::CommitTuningRecord); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK") + .set_body_method(&DatabaseNode::GetTopK); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") + .set_body_method(&DatabaseNode::GetMeasurementRecord); + +TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc new file mode 100644 index 0000000000..ef4a3d41bd --- /dev/null +++ b/src/relax/transform/tuning_api/primitives.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/tuning_api/primitives.cc + * \brief Primitives of tuning APIs. 
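A hedged sketch of constructing and querying the JSON-backed database defined above; the file names are placeholders and the tuning-API header location is an assumption:

#include <tvm/ir/module.h>
#include <tvm/relax/tuning_api.h>  // assumed header for the tuning API declarations
#include <tvm/target/target.h>

void QueryBest(tvm::IRModule mod, tvm::Target target) {
  tvm::relax::Database db = tvm::relax::Database::JSONDatabase(
      "workloads.json", "tuning_records.json", "measurements.json",
      /*allow_missing=*/true);
  tvm::meta_schedule::Workload wkl = db->CommitWorkload(mod);
  // Records come back best mean run time first, since the stored records are
  // ordered by SortTuningRecordByMeanRunSecs defined earlier.
  tvm::Array<tvm::relax::TuningRecord> best = db->GetTopK(wkl, target, /*top_k=*/3);
}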
+ */ + +#include + +#include "../../../meta_schedule/utils.h" +namespace tvm { +namespace relax { + +Choice::Choice(String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + ObjectPtr n = make_object(); + n->transform_func_key = std::move(transform_func_key); + n->transform_func_args = std::move(transform_func_args); + n->constr_func_key = std::move(constr_func_key); + n->constr_func_args = std::move(constr_func_args); + data_ = std::move(n); +} + +// TODO(sunggg): Currently, it only supports an array of primitive data types. +ObjectRef ChoiceNode::AsJSON() const { + Array json_transfrom_args, json_constr_args; + for (ObjectRef arg : this->transform_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_transfrom_args.push_back(String(b64_arg)); + } + for (ObjectRef arg : this->constr_func_args) { + std::string json_arg = tvm::SaveJSON(arg); + std::string b64_arg = meta_schedule::Base64Encode(json_arg); + json_constr_args.push_back(String(b64_arg)); + } + return Array{ + this->transform_func_key, + json_transfrom_args, + this->constr_func_key, + json_constr_args, + }; +} + +Choice Choice::FromJSON(const ObjectRef& json) { + // Parse `json` into `choice` + String transform_func_key, constr_func_key; + Array transform_func_args, constr_func_args; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 4); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + const auto* arr2 = arr->at(2).as(); + const auto* arr3 = arr->at(3).as(); + ICHECK(arr0 && arr1 && arr2 && arr3); + transform_func_key = GetRef(arr0); + { + transform_func_args.reserve(arr1->size()); + for (const ObjectRef& elem : *arr1) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + transform_func_args.push_back(arg); + } + } + constr_func_key = GetRef(arr2); + { + constr_func_args.reserve(arr3->size()); + for (const ObjectRef& elem : *arr3) { + String b64_arg = Downcast(elem); + std::string json_arg = meta_schedule::Base64Decode(b64_arg); + ObjectRef arg = LoadJSON(json_arg); + constr_func_args.push_back(arg); + } + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: " + << json; + throw; + } + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); +} + +Knob::Knob(String name, Map choices) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->choices = std::move(choices); + data_ = std::move(n); +} + +ObjectRef KnobNode::AsJSON() const { + Map json_choices; + for (auto const& x : choices) { + json_choices.Set(x.first, x.second->AsJSON()); + } + return Array{ + /* 0: name */ std::move(name), + /* 1: choices */ std::move(json_choices), + }; +} + +Knob Knob::FromJSON(const ObjectRef& json) { + // Parse `json` into `name` and `choices` + String name; + Map choices; + try { + const ArrayNode* arr = json.as(); + ICHECK(arr && arr->size() == 2); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + name = GetRef(arr0); + for (auto const& x : GetRef>(arr1)) { + String decision = x.first; + Choice choice = Choice::FromJSON(x.second); + choices.Set(decision, choice); + } + } catch (const tvm::Error& e) { + LOG(FATAL) + << "ValueError: The json entry of a choice should contain a set of two strings, but gets: 
" + << json; + throw; + } + return Knob(name, choices); +} + +Trace::Trace() { data_ = make_object(); } + +Trace::Trace(IRModule in_mod, Array knobs, Array decisions) { + ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match"; + // Deep-copy IRModule + auto func_deepcopy = runtime::Registry::Get("relax.tuning_api.deepcopy_irmodule"); + ICHECK(func_deepcopy); + IRModule out_mod = (*func_deepcopy)(in_mod); + // Apply the decision history if provided + int size = knobs.size(); + for (int i = 0; i < size; i++) { + out_mod = knobs[i]->Apply(out_mod, decisions[i]); + } + + ObjectPtr n = make_object(); + n->in_mod = std::move(in_mod); + n->out_mod = std::move(out_mod); + n->knobs = std::move(knobs); + n->decisions = std::move(decisions); + n->size = std::move(size); + data_ = std::move(n); +} + +ObjectRef TraceNode::AsJSON(bool include_in_mod) const { + ICHECK(this->Verify()) << "Trace should be valid"; + + Array json_knobs; + Array json_decisions; + + int size = this->size; + json_knobs.reserve(size); + json_decisions.reserve(size); + + for (int i = 0; i < size; i++) { + const Knob& knob = this->knobs[i]; + const String& decision = this->decisions[i]; + + json_knobs.push_back(knob->AsJSON()); + json_decisions.push_back(decision); + } + if (include_in_mod) { + std::string json_mod = tvm::SaveJSON(this->in_mod); + std::string b64_mod = meta_schedule::Base64Encode(json_mod); + return Array{json_knobs, json_decisions, String(b64_mod)}; + } else { + return Array{json_knobs, json_decisions}; + } +} + +Trace Trace::FromJSON(const ObjectRef& json) { + // Parse `json` into `trace` + IRModule in_mod; + Array knobs; + Array decisions; + try { + const ArrayNode* arr = json.as(); + // A trace will have 2 or 3 entries depending on `include_irmod` parameter. 
+ ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); + + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); + ICHECK(arr0 && arr1); + + for (const ObjectRef& elem : *arr0) { + knobs.push_back(Knob::FromJSON(elem)); + } + + for (const ObjectRef& elem : *arr1) { + decisions.push_back(Downcast(elem)); + } + + // When `include_irmod = true` + if (arr->size() == 3) { + const auto* arr2 = arr->at(2).as(); + String b64_mod = GetRef(arr2); + ICHECK(arr2); + std::string json_mod = meta_schedule::Base64Decode(b64_mod); + in_mod = Downcast(LoadJSON(json_mod)); + } + } catch (const tvm::Error& e) { + LOG(FATAL) << "ValueError: Malformed Trace format - " << json; + throw; + } + return Trace(in_mod, knobs, decisions); +} + +/**************** FFI ****************/ +TVM_REGISTER_NODE_TYPE(ChoiceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Choice") + .set_body_typed([](String transform_func_key, Array transform_func_args, + String constr_func_key, Array constr_func_args) { + return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") + .set_body_method(&ChoiceNode::GetTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") + .set_body_method(&ChoiceNode::GetConstrFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") + .set_body_method(&ChoiceNode::ApplyTransformFunc); +TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") + .set_body_method(&ChoiceNode::CheckConstr); + +TVM_REGISTER_NODE_TYPE(KnobNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Knob") + .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") + .set_body_method(&KnobNode::IsValidDecision); +TVM_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); + +TVM_REGISTER_NODE_TYPE(TraceNode); +TVM_REGISTER_GLOBAL("relax.tuning_api.Trace") + .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { + return Trace(in_mod, knobs, decisions); + }); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod") + .set_body_method(&TraceNode::SetOutMod); + +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); +} // namespace relax +} // namespace tvm diff --git a/src/relax/usmp/analysis/extract_buffer_info.cc b/src/relax/usmp/analysis/extract_buffer_info.cc new file mode 100644 index 0000000000..bbc1e06992 --- /dev/null +++ b/src/relax/usmp/analysis/extract_buffer_info.cc @@ -0,0 +1,904 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
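The serialization above is symmetric, so a Trace can be rebuilt from its own AsJSON output. A hedged round-trip sketch; it assumes the knob actually defines a decision named "choice_a" and that the tuning-API header is available:

#include <tvm/ir/module.h>
#include <tvm/relax/tuning_api.h>  // assumed header for the tuning API declarations

tvm::relax::Trace MakeAndRoundTrip(tvm::IRModule mod, tvm::relax::Knob knob) {
  // The constructor deep-copies `mod` and replays each decision on the copy.
  tvm::relax::Trace trace(mod, /*knobs=*/{knob}, /*decisions=*/{"choice_a"});
  tvm::runtime::ObjectRef json = trace->AsJSON(/*include_in_mod=*/true);
  return tvm::relax::Trace::FromJSON(json);
}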
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relax/usmp/analysis/extract_buffer_info.cc + * + * \brief This analysis pass consumes a TIR IRModule with a main function + * that defines a ordering in the callees to operators and produces BufferInfo + * objects that contains information about tir.allocate nodes and liveness + * conflicts between other tir.allocate nodes. + */ +#include +#include +#include +#include +#include + +#include + +#include "../../../runtime/thread_storage_scope.h" +#include "tvm/relax/attrs/memory.h" +#include "tvm/relax/expr_functor.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/function.h" +#include "tvm/tir/stmt.h" +#include "tvm/tir/stmt_functor.h" + +namespace tvm { + +namespace tir::usmp { +class TIRInfoExtractor; +} + +namespace relax::usmp { +class RelaxInfoExtractor; +} + +class BufferInfoExtractor; + +namespace usmp { +/*! + * \brief The class to keep buffer information used by this pass. + * + * The Relax and TIR visitors would initiate the traversal from the main Relax + * function and visit into the operator PrimFuncs. They will + * create unique BufferInfo objects for each Relax/TIR allocation. + * + * Every time the buffer variable of an allocation is referenced + * it will be recorded using the stmt index. However, note that + * the same buffer variable could be references multiple times + * from different calls. Thereafter, a sweep is done on all the + * BufferInfo objects using the per-call liveness events. In the sweep, + * The BufferInfo objects that are live together will be recorded as + * mutual conflicts of each other. + */ +class BufferInfoPassData { + using BufferInfo = tir::usmp::BufferInfo; + using Call = tir::Call; + using PrimFunc = tir::PrimFunc; + using For = tir::For; + using Allocate = tir::Allocate; + using AllocateConst = tir::AllocateConst; + using Function = relax::Function; + + /*! + * \brief Maintains the mapping of BufferInfo to their associated TIR Statements or Relax Expr. + */ + Map buffer_info_map_; + /*! + * \brief Records the order of calls in the main for stability. + */ + std::vector call_order_; + /*! + * \brief Lookup to avoid adding duplicates to `call_order_`. + */ + std::unordered_set call_order_contents_; + /*! + * \brief Records first access in-terms of TIR Stmts/Relax Expr to each buffer per call + * + * This is because multiple calls could happen to the same PrimFunc. + */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_info_start_stmt_idx_; + /*! + * \brief Records last access in-terms of TIR Stmts/Relax Expr to each buffer per call + * + * This is because multiple calls could happen to the same PrimFunc. + */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_info_end_stmt_idx_; + /*! + * \brief This structure contains information regarding a TIR Allocate node / Relax call node to + * alloc_tensor. + */ + struct AllocateInfo { + runtime::ObjectRef Allocate; + BaseFunc func; + BaseExpr call; + }; + + /*! 
+ * \brief Maintains the mapping of TIR buffer variable / Relax Var to their allocate infos to + * ensure that only one BufferInfo object is created. + */ + std::unordered_map allocate_infos; + /*! + * \brief Indicates a count of stmts visited so far to use as a metric of liveness + */ + int current_stmt_idx_ = 0; + /*! + * \brief This structure is supposed to contain information around the scope + * the visitor is currently in. + */ + struct ScopeInfo { + /*! + * \brief We need to record access per call + */ + BaseExpr call; + /*! + * \brief Having access to PrimFunc/RelaxFunc metadata is useful + */ + BaseFunc func; + /*! + * \brief We currently support only serial for loops. Therefore + * need to know what kind of for loop the visitor is in. Only used when visiting PrimFuncs. + */ + For for_loop; + /*! + * \brief We record the live TIR allocate_nodes and Relax allocate Expr because once in loops + * the liveness range has to be extended to the whole of the nested + * loops structure. + */ + std::unordered_set allocate_nodes; + /* + * \brief We record the live allocate_const_nodes because once in loops + * the liveness range has to be extended to the whole of the nested + * loops structure. + */ + std::unordered_set allocate_const_nodes; + /*! + * \brief This is recorded to extend the liveness of all allocates within + * nested loop structure. Only used for PrimFuncs. + */ + Integer initial_stmt_of_the_nested_loops; + }; + std::stack scope_stack_; + + /*! + * \brief A liveness event tracks when + * traversing the tir.Stmts/Relax.Expr where allocations + * begin or cease to be Live. This particular struct + * is used to solve interval overlap problem using + * a sweep-line algorithm. For that, we need to record + * where the liveness event occurred in a chronological + * order. + */ + enum LivenessEventType { START = 0, END = 1 }; + struct LivenessEvent { + size_t tick; + LivenessEventType le_type; + BufferInfo buffer_info; + bool operator==(const LivenessEvent& other) { + if (tick == other.tick && le_type == other.le_type && buffer_info == other.buffer_info) { + return true; + } + return false; + } + }; + /*! + * \brief We need to create unique buffer name is the same name is used in + * two allocate nodes for clarity for memory planning algorithms. + */ + std::string GetUniqueBufferName(std::string name); + + /*! + * \brief This is per buffer name counter to aid the generating the above + * unique name. + */ + std::unordered_map buffer_names; + /*! + * \brief The Relax main function calls to external functions to be able to + * support BYOC. Therefore, this Map records functions that are present + * in the IRModule by name/ + */ + Map functions_; + + /*! + * \brief The IRModule being analyzed. 
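The LivenessEvent machinery above reduces conflict detection to a classic interval-overlap sweep: sort events by tick, keep the set of live buffers, and record a conflict whenever a buffer starts while others are live. A standalone sketch (tie-breaking at equal ticks is left out; the real pass defines its own ordering):

#include <algorithm>
#include <cstddef>
#include <unordered_set>
#include <utility>
#include <vector>

enum class Ev { kStart, kEnd };
struct Event {
  std::size_t tick;
  Ev type;
  int buffer_id;
};

// Buffers whose live ranges overlap must not share memory.
std::vector<std::pair<int, int>> Conflicts(std::vector<Event> events) {
  std::sort(events.begin(), events.end(),
            [](const Event& a, const Event& b) { return a.tick < b.tick; });
  std::unordered_set<int> live;
  std::vector<std::pair<int, int>> out;
  for (const Event& e : events) {
    if (e.type == Ev::kStart) {
      for (int other : live) out.emplace_back(e.buffer_id, other);  // conflicts with all live
      live.insert(e.buffer_id);
    } else {
      live.erase(e.buffer_id);
    }
  }
  return out;
}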
+ */ + IRModule module_; + + friend class tvm::BufferInfoExtractor; + friend class tir::usmp::TIRInfoExtractor; + friend class relax::usmp::RelaxInfoExtractor; +}; + +} // namespace usmp + +namespace tir { +namespace usmp { + +class TIRInfoExtractor : public StmtExprVisitor { + using BufferInfoPassData = tvm::usmp::BufferInfoPassData; + + public: + explicit TIRInfoExtractor(BufferInfoPassData& pass_data) : pass_data_(pass_data) {} + + void VisitPrimFunc(const PrimFunc& func, const BaseExpr& call); + void UpdateAliases(const Array& args, const PrimFunc& func); + + private: + void VisitStmt(const Stmt& n) override; + void VisitStmt_(const AllocateNode* op) override; + void VisitStmt_(const AllocateConstNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const BufferLoadNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const ForNode* op) override; + + void RecordAllocateNodeInfo(const AllocateNode* op); + void RecordAllocateConstNodeInfo(const AllocateConstNode* op); + + BufferInfoPassData& pass_data_; +}; + +void TIRInfoExtractor::VisitStmt(const Stmt& n) { + pass_data_.current_stmt_idx_ += 1; + StmtExprVisitor::VisitStmt(n); +} + +void TIRInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { + auto size_bytes = tir::usmp::CalculateExtentsSize(op); + // We only statically memory plan allocates with known + // compile time sizes. + if (size_bytes.defined()) { + if (pass_data_.allocate_infos.find(op->buffer_var) == pass_data_.allocate_infos.end()) { + // By default, the core compiler is assumed to attach the a default pool to each allocate. + ICHECK(op->annotations.count(tir::usmp::kPoolCandidatesAllocateAttr)) + << "Every statically sized allocate node needs an pool candidate attribute"; + auto pool_candidates = + Downcast>(op->annotations[tir::usmp::kPoolCandidatesAllocateAttr]); + + ICHECK(pool_candidates.size() > 0) + << "The AssignPoolInfo pass should at least attach a single PoolInfo. If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + PrimFunc func = Downcast(pass_data_.scope_stack_.top().func); + Optional executor_config = + pass_data_.module_->GetAttr(tvm::attr::kExecutor); + Integer workspace_alignment = 16; + if (executor_config) { + workspace_alignment = + executor_config.value()->GetAttr("workspace-byte-alignment").value_or(16); + } + + BufferInfoKind bi_kind = BufferInfoKind::kIntermediate; + String buffer_info_name = op->buffer_var->name_hint; + if (op->annotations.find(kInputTensorAllocate) != op->annotations.end()) { + bi_kind = BufferInfoKind::kInput; + // using original input name instead of the buffer_var name + // because this name will be used in the lowering to convey + // the pool allocation. + buffer_info_name = Downcast(op->annotations[kInputTensorAllocate]); + } else if (op->annotations.find(kOutputTensorAllocate) != op->annotations.end()) { + bi_kind = BufferInfoKind::kOutput; + // using original output name instead of the buffer_var name + // because this name will be used in the lowering to convey + // the pool allocation. 
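The "known compile time sizes" rule above is the gatekeeper for the whole analysis: only allocations whose extents are all constants get BufferInfo objects; everything else is left for codegen. A standalone sketch of that test (illustrative encoding; -1 stands for a symbolic extent):

#include <cstdint>
#include <vector>

// Returns the static size in bytes, or -1 when any extent is dynamic and the
// allocation therefore cannot be statically memory-planned.
int64_t StaticSizeBytes(const std::vector<int64_t>& extents, int dtype_bytes) {
  int64_t elems = 1;
  for (int64_t e : extents) {
    if (e < 0) return -1;  // dynamic extent: skip planning
    elems *= e;
  }
  return elems * dtype_bytes;
}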
+ buffer_info_name = Downcast(op->annotations[kOutputTensorAllocate]); + } + auto buffer_info = BufferInfo(pass_data_.GetUniqueBufferName(buffer_info_name), size_bytes, + pool_candidates, workspace_alignment, bi_kind); + auto allocate = GetRef(op); + pass_data_.allocate_infos[op->buffer_var] = BufferInfoPassData::AllocateInfo{ + allocate, Downcast(pass_data_.scope_stack_.top().func), + pass_data_.scope_stack_.top().call}; + pass_data_.buffer_info_map_.Set(buffer_info, allocate); + } else { + // Update the allocate info with the latest call + BufferInfoPassData::AllocateInfo ai = pass_data_.allocate_infos[op->buffer_var]; + ai.call = pass_data_.scope_stack_.top().call; + pass_data_.allocate_infos[op->buffer_var] = ai; + } + } +} + +void TIRInfoExtractor::VisitStmt_(const AllocateNode* op) { + using ScopeInfo = BufferInfoPassData::ScopeInfo; + ScopeInfo& current_scope_info = pass_data_.scope_stack_.top(); + const auto& type = Downcast(op->buffer_var->type_annotation); + const auto& storage_scope = runtime::StorageScope::Create(type->storage_scope); + + // If the allocate is in a for loop, USMP currently only looks at serial for loops. + // If its not a serial for loop, then memory planner will omit them in the current memory planning + // process leaving them to as tir.allocate nodes for codegen. Additionally, the USMP can only work + // with buffers that have global storage_scope + + if (storage_scope.rank == runtime::StorageRank::kGlobal) { + if (!current_scope_info.for_loop.defined()) { + RecordAllocateNodeInfo(op); + } else if (current_scope_info.for_loop.defined() && + current_scope_info.for_loop->kind == ForKind::kSerial) { + RecordAllocateNodeInfo(op); + } + } + StmtExprVisitor::VisitStmt(op->body); + current_scope_info.allocate_nodes.erase(GetRef(op)); +} + +void TIRInfoExtractor::VisitStmt_(const AllocateConstNode* op) { + using ScopeInfo = BufferInfoPassData::ScopeInfo; + ScopeInfo& current_scope_info = pass_data_.scope_stack_.top(); + RecordAllocateConstNodeInfo(op); + StmtExprVisitor::VisitStmt(op->body); + current_scope_info.allocate_const_nodes.erase(GetRef(op)); +} + +void TIRInfoExtractor::RecordAllocateConstNodeInfo(const AllocateConstNode* op) { + if (!op->annotations.count(kPoolCandidatesAllocateAttr)) { + return; + } + Integer size_bytes = CalculateExtentsSize(op); + ICHECK(size_bytes.defined()) << "constant node size should be defined"; + const auto& buffer_var = op->buffer_var; + if (pass_data_.allocate_infos.find(buffer_var) == pass_data_.allocate_infos.end()) { + // By default, the core compiler is assumed to attach the a default pool to each allocate. + ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) + << "Every statically sized allocate node needs an pool candidate attribute"; + auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); + ICHECK(pool_candidates.size() > 0) + << "The core compiler should at least attach a single PoolInfo. 
If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + PrimFunc func = Downcast(pass_data_.scope_stack_.top().func); + Optional executor_config = + pass_data_.module_->GetAttr(tvm::attr::kExecutor); + Integer alignment = 16; + if (executor_config) { + alignment = + executor_config.value()->GetAttr("constant-byte-alignment").value_or(alignment); + } + auto buffer_info = BufferInfo(pass_data_.GetUniqueBufferName(buffer_var->name_hint), size_bytes, + pool_candidates, alignment); + auto allocate = GetRef(op); + pass_data_.allocate_infos[buffer_var] = BufferInfoPassData::AllocateInfo{ + allocate, Downcast(pass_data_.scope_stack_.top().func), + pass_data_.scope_stack_.top().call}; + pass_data_.buffer_info_map_.Set(buffer_info, allocate); + } else { + // Update the allocate info with the latest call + BufferInfoPassData::AllocateInfo ai = pass_data_.allocate_infos[buffer_var]; + ai.call = pass_data_.scope_stack_.top().call; + pass_data_.allocate_infos[buffer_var] = ai; + } +} + +void TIRInfoExtractor::VisitStmt_(const ForNode* op) { + using ScopeInfo = BufferInfoPassData::ScopeInfo; + ScopeInfo si{pass_data_.scope_stack_.top().call, + pass_data_.scope_stack_.top().func, + GetRef(op), + pass_data_.scope_stack_.top().allocate_nodes, + pass_data_.scope_stack_.top().allocate_const_nodes, + pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops}; + if (!pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops.defined()) { + si.initial_stmt_of_the_nested_loops = Integer(pass_data_.current_stmt_idx_); + } + BaseExpr current_call = pass_data_.scope_stack_.top().call; + auto current_func = pass_data_.scope_stack_.top().func; + pass_data_.scope_stack_.push(si); + StmtExprVisitor::VisitStmt_(op); + // Extending the liveness to beginning of for-loop next and end of the current for-loop + for (const runtime::ObjectRef& ref : pass_data_.scope_stack_.top().allocate_nodes) { + BufferInfoPassData::AllocateInfo ai; + if (ref->IsInstance()) { + auto expr = runtime::Downcast(ref); + ai = pass_data_.allocate_infos[expr]; + } else if (ref->IsInstance()) { + auto allocate = runtime::Downcast(ref); + ai = pass_data_.allocate_infos[allocate->buffer_var]; + } + auto allocate = ref; + BaseExpr update_call = current_call; + // If the allocate does not belong to current func + // We need to update the call to which the allocate belongs to + if (ai.func != current_func) { + update_call = ai.call; + } + if (pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops->value < + pass_data_.buffer_info_start_stmt_idx_[update_call][allocate].IntValue()) { + pass_data_.buffer_info_start_stmt_idx_[update_call].Set( + allocate, pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops->value); + } + if (pass_data_.current_stmt_idx_ > + pass_data_.buffer_info_end_stmt_idx_[update_call][allocate].IntValue()) { + pass_data_.buffer_info_end_stmt_idx_[update_call].Set(allocate, pass_data_.current_stmt_idx_); + } + } + pass_data_.scope_stack_.pop(); +} + +void TIRInfoExtractor::VisitExpr_(const BufferLoadNode* op) { + this->VisitExpr(op->buffer->data); + StmtExprVisitor::VisitExpr_(op); +} + +void TIRInfoExtractor::VisitStmt_(const BufferStoreNode* op) { + this->VisitExpr(op->buffer->data); + StmtExprVisitor::VisitStmt_(op); +} + +void TIRInfoExtractor::VisitExpr_(const VarNode* op) { + auto var = GetRef(op); + auto current_call = pass_data_.scope_stack_.top().call; + auto current_func = pass_data_.scope_stack_.top().func; + if 
(pass_data_.allocate_infos.count(var)) { + auto allocate = pass_data_.allocate_infos[var].Allocate; + auto allocate_func = pass_data_.allocate_infos[var].func; + BaseExpr update_call = current_call; + if (allocate_func != current_func) { + // If the allocate node does not belong to the current primfunc. + // It's access should be reported to the call to PrimFunc that + // Allocate belong to. + update_call = pass_data_.allocate_infos[var].call; + } + if (pass_data_.buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) { + pass_data_.buffer_info_start_stmt_idx_[update_call].Set(allocate, + pass_data_.current_stmt_idx_); + } + pass_data_.buffer_info_end_stmt_idx_[update_call].Set(allocate, pass_data_.current_stmt_idx_); + + BufferInfoPassData::ScopeInfo& currect_scope_info = pass_data_.scope_stack_.top(); + if (currect_scope_info.for_loop.defined()) { + if (allocate->IsInstance()) { + currect_scope_info.allocate_nodes.insert(Downcast(allocate)); + } else if (allocate->IsInstance()) { + currect_scope_info.allocate_const_nodes.insert(Downcast(allocate)); + } else if (allocate->IsInstance()) { + currect_scope_info.allocate_nodes.insert(Downcast(allocate)); + } else { + LOG(FATAL) << "Handling of " << allocate->GetTypeKey() << " is not implemented"; + } + } + } + StmtExprVisitor::VisitExpr_(op); +} + +Array static GetMatchedBuffers(const PrimFunc& func) { + Array buffer_vars; + for (unsigned int i = 0; i < func->params.size() - 1; i++) { + Var param = func->params[i]; + buffer_vars.push_back(func->buffer_map[param]->data); + } + Var last_param = func->params.back(); + // Checks whether last var is present in the buffer map + // because it could be the resource handle + if (func->buffer_map.find(last_param) != func->buffer_map.end()) { + buffer_vars.push_back(func->buffer_map[last_param]->data); + } + return buffer_vars; +} + +void TIRInfoExtractor::UpdateAliases(const Array& args, const PrimFunc& func) { + auto param_buffers = GetMatchedBuffers(func); + // Last var could be a resource handle that does not have a Buffer + ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size()); + for (size_t i = 0; i < param_buffers.size(); i++) { + auto arg = args[i]; + auto param_buf = param_buffers[i]; + // If tir.allocates are passed in to functions + // The function params are re-directed to point + // to the original allocate + if (arg->IsInstance()) { + auto load = Downcast(arg); + if (pass_data_.allocate_infos.count(load->buffer_var)) { + pass_data_.allocate_infos[param_buf] = pass_data_.allocate_infos[load->buffer_var]; + } + } else if (arg->IsInstance()) { + auto var = Downcast(arg); + if (pass_data_.allocate_infos.count(var)) { + pass_data_.allocate_infos[param_buf] = pass_data_.allocate_infos[var]; + } + } else if (arg->IsInstance()) { + auto var = Downcast(arg); + if (pass_data_.allocate_infos.count(var)) { + pass_data_.allocate_infos[param_buf] = pass_data_.allocate_infos[var]; + } + } + } +} + +void TIRInfoExtractor::VisitPrimFunc(const PrimFunc& func, const BaseExpr& call) { + BufferInfoPassData::ScopeInfo si{call, + func, + pass_data_.scope_stack_.top().for_loop, + pass_data_.scope_stack_.top().allocate_nodes, + pass_data_.scope_stack_.top().allocate_const_nodes, + pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops}; + if (pass_data_.call_order_contents_.count(call) == 0) { + pass_data_.call_order_contents_.insert(call); + pass_data_.call_order_.push_back(call); + } + pass_data_.scope_stack_.push(si); + this->VisitStmt(func->body); + 
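UpdateAliases above redirects a callee's parameter buffers to the caller's allocation records, so that accesses inside the PrimFunc extend the liveness of the original buffer rather than a fresh one. The same idea standalone (illustrative types; the real code also tolerates a trailing resource-handle argument):

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

using AllocTable = std::unordered_map<std::string, int>;  // var name -> allocation id

void UpdateAliasesSketch(AllocTable& table, const std::vector<std::string>& args,
                         const std::vector<std::string>& params) {
  for (std::size_t i = 0; i < params.size() && i < args.size(); ++i) {
    auto it = table.find(args[i]);
    if (it != table.end()) table[params[i]] = it->second;  // param aliases caller's buffer
  }
}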
+void TIRInfoExtractor::VisitPrimFunc(const PrimFunc& func, const BaseExpr& call) {
+  BufferInfoPassData::ScopeInfo si{call,
+                                   func,
+                                   pass_data_.scope_stack_.top().for_loop,
+                                   pass_data_.scope_stack_.top().allocate_nodes,
+                                   pass_data_.scope_stack_.top().allocate_const_nodes,
+                                   pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops};
+  if (pass_data_.call_order_contents_.count(call) == 0) {
+    pass_data_.call_order_contents_.insert(call);
+    pass_data_.call_order_.push_back(call);
+  }
+  pass_data_.scope_stack_.push(si);
+  this->VisitStmt(func->body);
+  pass_data_.scope_stack_.pop();
+}
+
+void TIRInfoExtractor::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+    StringImm func_name = Downcast<StringImm>(op->args[0]);
+    if (pass_data_.functions_.find(func_name->value) != pass_data_.functions_.end()) {
+      auto func = pass_data_.functions_.at(func_name->value);
+      auto actual_args = Array<BaseExpr>(op->args.begin() + 1, op->args.end());
+      this->UpdateAliases(actual_args, Downcast<PrimFunc>(func));
+      VisitPrimFunc(Downcast<PrimFunc>(func), GetRef<Call>(op));
+      return;
+    }
+  }
+  if (op->op->IsInstance<PrimFuncNode>()) {
+    auto func = Downcast<PrimFunc>(op->op);
+    auto actual_args = Array<BaseExpr>(op->args.begin(), op->args.end());
+    this->UpdateAliases(actual_args, func);
+    VisitPrimFunc(func, GetRef<Call>(op));
+    return;
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+}  // namespace usmp
+}  // namespace tir
+
+namespace relax {
+namespace usmp {
+
+class RelaxInfoExtractor : public relax::ExprVisitor {
+  using BufferInfoPassData = tvm::usmp::BufferInfoPassData;
+
+ public:
+  explicit RelaxInfoExtractor(BufferInfoPassData& pass_data) : pass_data_(pass_data) {}
+
+  void VisitRelaxFunc(const Function& func, const Call& call);
+
+ private:
+  void VisitExpr(const Expr& expr) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+
+  void VisitBinding_(const VarBindingNode* binding);
+
+  void VisitAllocTensorVarBinding(const VarBindingNode* op);
+  void RecordAllocateNodeInfo(const VarBindingNode* op);
+
+  BufferInfoPassData& pass_data_;
+};
+
+void RelaxInfoExtractor::VisitExpr(const Expr& expr) {
+  pass_data_.current_stmt_idx_ += 1;
+  ExprVisitor::VisitExpr(expr);
+}
+
+void RelaxInfoExtractor::VisitExpr_(const VarNode* op) {
+  auto var = GetRef<Var>(op);
+
+  BaseExpr current_call = pass_data_.scope_stack_.top().call;
+  BaseFunc current_func = pass_data_.scope_stack_.top().func;
+  if (pass_data_.allocate_infos.count(var)) {
+    auto allocate = pass_data_.allocate_infos[var].Allocate;
+    auto allocate_func = pass_data_.allocate_infos[var].func;
+    BaseExpr update_call = current_call;
+    if (allocate_func != current_func) {
+      // If the allocate node does not belong to the current func,
+      // the access should be reported to the call to the func that
+      // the node belongs to.
+      update_call = pass_data_.allocate_infos[var].call;
+    }
+    if (pass_data_.buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) {
+      pass_data_.buffer_info_start_stmt_idx_[update_call].Set(allocate,
+                                                              pass_data_.current_stmt_idx_);
+    }
+    pass_data_.buffer_info_end_stmt_idx_[update_call].Set(allocate, pass_data_.current_stmt_idx_);
+
+    BufferInfoPassData::ScopeInfo& current_scope_info = pass_data_.scope_stack_.top();
+    if (current_scope_info.for_loop.defined()) {
+      current_scope_info.allocate_nodes.insert(allocate);
+    }
+  }
+  ExprVisitor::VisitExpr_(op);
+}
+
+void RelaxInfoExtractor::VisitBinding_(const VarBindingNode* binding) {
+  auto node = GetRef<VarBinding>(binding);
+  if (node->value->IsInstance<CallNode>()) {
+    auto call_node = runtime::Downcast<Call>(node->value);
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    if (call_node->op == alloc_tensor_op) {
+      VisitAllocTensorVarBinding(binding);
+      return;
+    }
+  } else if (node->value->IsInstance<VarNode>()) {
+    // Update the allocate info map with the alias.
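+    // For example, after a binding `lv1 = lv0` where lv0 holds the result of
+    // an alloc_tensor, lv1 is mapped to lv0's AllocateInfo so that later uses
+    // of lv1 keep extending the liveness of the same underlying buffer.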
+    auto aliased_var = runtime::Downcast<Var>(node->value);
+    if (pass_data_.allocate_infos.count(aliased_var)) {
+      BufferInfoPassData::AllocateInfo ai = pass_data_.allocate_infos[aliased_var];
+      ai.call = pass_data_.scope_stack_.top().call;
+      pass_data_.allocate_infos[node->var] = ai;
+    }
+  }
+  ExprVisitor::VisitBinding_(binding);
+}
+
+static Integer CalculateRelaxExtentsSize(const DataType& dtype, const Array<PrimExpr>& extents) {
+  size_t element_size_bytes = dtype.bytes();
+  size_t num_elements = 1;
+  for (const auto& ext : extents) {
+    if (ext->IsInstance<IntImmNode>()) {
+      num_elements *= Downcast<IntImm>(ext)->value;
+    }
+  }
+  return Integer(num_elements * element_size_bytes);
+}
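+
+// For example, CalculateRelaxExtentsSize for a float32 tensor (4 bytes per
+// element) with static extents (2, 3, 8) yields 4 * 2 * 3 * 8 = 192 bytes.
+// Only IntImm extents contribute to the element count, so a symbolic extent
+// is effectively treated as 1.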
If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + BaseFunc func = pass_data_.scope_stack_.top().func; + Optional executor_config = + pass_data_.module_->GetAttr(tvm::attr::kExecutor); + Integer workspace_alignment = 16; + if (executor_config) { + workspace_alignment = + executor_config.value()->GetAttr("workspace-byte-alignment").value_or(16); + } + + tir::usmp::BufferInfoKind bi_kind = tir::usmp::BufferInfoKind::kIntermediate; + String buffer_info_name = op->var->name_hint(); + auto buffer_info = + tir::usmp::BufferInfo(pass_data_.GetUniqueBufferName(buffer_info_name), size_bytes, + pool_candidates, workspace_alignment, bi_kind); + pass_data_.allocate_infos[var_node] = BufferInfoPassData::AllocateInfo{ + var_node, pass_data_.scope_stack_.top().func, pass_data_.scope_stack_.top().call}; + pass_data_.buffer_info_map_.Set(buffer_info, var_node); + } else { + // Update the allocate info with the latest call + BufferInfoPassData::AllocateInfo ai = pass_data_.allocate_infos[var_node]; + ai.call = pass_data_.scope_stack_.top().call; + pass_data_.allocate_infos[var_node] = ai; + } + } +} + +void RelaxInfoExtractor::VisitAllocTensorVarBinding(const VarBindingNode* op) { + BufferInfoPassData::ScopeInfo& current_scope_info = pass_data_.scope_stack_.top(); + RecordAllocateNodeInfo(op); + ExprVisitor::VisitBinding_(op); + current_scope_info.allocate_nodes.erase(op->var); +} + +void RelaxInfoExtractor::VisitExpr_(const CallNode* op) { + auto node = GetRef(op); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + if (op->op == alloc_tensor_op) { + // Handled by the VarBinding visit method + return; + } + + if (op->op->IsInstance()) { + String func_name = runtime::Downcast(op->op)->global_symbol; + if (pass_data_.functions_.find(func_name) != pass_data_.functions_.end()) { + auto func = pass_data_.functions_.at(func_name); + if (func->IsInstance()) { + auto actual_args = Array(op->args.begin(), op->args.end()); + tir::usmp::TIRInfoExtractor tir_info_extractor = tir::usmp::TIRInfoExtractor(pass_data_); + tir_info_extractor.UpdateAliases(actual_args, Downcast(func)); + tir_info_extractor.VisitPrimFunc(Downcast(func), GetRef(op)); + return; + } + } + } + if (op->op->IsInstance()) { + auto func = Downcast(op->op); + ICHECK(false) << "Calls to Relax functions are not supported." 
+                  << PrettyPrint(func);
+  }
+  if (op->op->IsInstance<GlobalVarNode>()) {
+    auto global_var = Downcast<GlobalVar>(op->op);
+    ICHECK(false) << "Calls to Relax functions are not supported: " << global_var->name_hint;
+  }
+  ExprVisitor::VisitExpr_(op);
+}
+
+void RelaxInfoExtractor::VisitRelaxFunc(const Function& func, const Call& call) {
+  BufferInfoPassData::ScopeInfo si{call,
+                                   func,
+                                   pass_data_.scope_stack_.top().for_loop,
+                                   pass_data_.scope_stack_.top().allocate_nodes,
+                                   pass_data_.scope_stack_.top().allocate_const_nodes,
+                                   pass_data_.scope_stack_.top().initial_stmt_of_the_nested_loops};
+  if (pass_data_.call_order_contents_.count(call) == 0) {
+    pass_data_.call_order_contents_.insert(call);
+    pass_data_.call_order_.push_back(call);
+  }
+  pass_data_.scope_stack_.push(si);
+  this->VisitExpr(func->body);
+  pass_data_.scope_stack_.pop();
+}
+
+}  // namespace usmp
+}  // namespace relax
+
+class BufferInfoExtractor {
+  using BufferInfoPassData = tvm::usmp::BufferInfoPassData;
+  using BufferInfoAnalysis = tir::usmp::BufferInfoAnalysis;
+
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) {
+    pass_data_.module_ = module;
+    for (const auto& gv_func : module->functions) {
+      if (gv_func.second->IsInstance<tir::PrimFuncNode>() ||
+          gv_func.second->IsInstance<relax::FunctionNode>()) {
+        pass_data_.functions_.Set(gv_func.first->name_hint, Downcast<BaseFunc>(gv_func.second));
+      }
+    }
+    // Pushing a scope info for the initial body of the main function
+    pass_data_.scope_stack_.push(BufferInfoPassData::ScopeInfo());
+  }
+  BufferInfoAnalysis operator()(const relax::Function& func);
+
+ private:
+  BufferInfoPassData pass_data_;
+};
+
+std::string usmp::BufferInfoPassData::GetUniqueBufferName(std::string name) {
+  if (buffer_names.find(name) == buffer_names.end()) {
+    buffer_names[name] = 1;
+    return name;
+  } else {
+    buffer_names[name] = buffer_names[name] + 1;
+    return name + std::to_string(buffer_names[name]);
+  }
+}
+
+tir::usmp::BufferInfoAnalysis BufferInfoExtractor::operator()(const relax::Function& main_func) {
+  using LivenessEvent = BufferInfoPassData::LivenessEvent;
+  using LivenessEventType = BufferInfoPassData::LivenessEventType;
+  using BufferInfo = tir::usmp::BufferInfo;
+  using RelaxInfoExtractor = relax::usmp::RelaxInfoExtractor;
+
+  RelaxInfoExtractor relax_info_extractor = RelaxInfoExtractor(pass_data_);
+  relax_info_extractor.VisitRelaxFunc(main_func, relax::Call());
+
+  // Create a vector of liveness events
+  // associated with each BufferInfo.
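+  // Each buffer contributes a START event at the first statement index where
+  // it is accessed and an END event at the last, recorded per call site.
+  // Sweeping these events in order reveals, at every tick, which buffers are
+  // simultaneously live and hence must not share memory.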
+  std::vector<LivenessEvent> le_events_timeline;
+  for (const auto& kv1 : pass_data_.buffer_info_map_) {
+    auto allocate = kv1.second;
+    auto buffer_info = Downcast<BufferInfo>(kv1.first);
+
+    ICHECK(pass_data_.call_order_.size() >= pass_data_.buffer_info_end_stmt_idx_.size());
+
+    for (const BaseExpr& call : pass_data_.call_order_) {
+      Map<ObjectRef, Integer> buffer_info_starts =
+          pass_data_.buffer_info_start_stmt_idx_[call];
+      if (buffer_info_starts.find(allocate) != buffer_info_starts.end()) {
+        LivenessEvent le_event_start;
+        le_event_start.buffer_info = buffer_info;
+        le_event_start.le_type = LivenessEventType::START;
+        le_event_start.tick = buffer_info_starts[allocate].IntValue();
+        le_events_timeline.push_back(le_event_start);
+      }
+    }
+
+    for (const BaseExpr& call : pass_data_.call_order_) {
+      Map<ObjectRef, Integer> buffer_info_ends =
+          pass_data_.buffer_info_end_stmt_idx_[call];
+      if (buffer_info_ends.find(allocate) != buffer_info_ends.end()) {
+        LivenessEvent le_event_end;
+        le_event_end.buffer_info = buffer_info;
+        le_event_end.le_type = LivenessEventType::END;
+        le_event_end.tick = buffer_info_ends[allocate].IntValue();
+        le_events_timeline.push_back(le_event_end);
+      }
+    }
+  }
+
+  // Sort the liveness events in chronological order. For events that are
+  // simultaneous, the START event takes precedence.
+  std::sort(le_events_timeline.begin(), le_events_timeline.end(),
+            [](const LivenessEvent& lhs, const LivenessEvent& rhs) {
+              if (lhs.tick < rhs.tick) {
+                return true;
+              } else if (lhs.tick == rhs.tick && lhs.le_type == LivenessEventType::START &&
+                         rhs.le_type == LivenessEventType::END) {
+                return true;
+              }
+              return false;
+            });
+
+  // Traverse the liveness events using an open set to track what
+  // is live while updating the conflicts throughout the linear traversal
+  int open_set_size = 0;
+  int max_open_set_size = 0;
+  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
+  for (const auto& le_event : le_events_timeline) {
+    if (le_event.le_type == LivenessEventType::START) {
+      for (const BufferInfo& open_buffer_info : open_set) {
+        open_buffer_info->conflicts.push_back(le_event.buffer_info);
+        if (le_event.buffer_info != open_buffer_info) {
+          le_event.buffer_info->conflicts.push_back(open_buffer_info);
+        }
+      }
+      open_set_size += le_event.buffer_info->size_bytes.IntValue();
+      if (open_set_size > max_open_set_size) {
+        max_open_set_size = open_set_size;
+      }
+      open_set.insert(le_event.buffer_info);
+    } else {
+      open_set_size -= le_event.buffer_info->size_bytes.IntValue();
+      open_set.erase(le_event.buffer_info);
+    }
+  }
+
+  // All ConstantPoolInfo items should have conflicts with each other,
+  // as they will be placed in the RO segment and pre-initialized.
+  // To achieve this, first split buffers into vars (WorkspacePoolInfo items)
+  // and constants (ConstantPoolInfo items):
+  Array<BufferInfo> buffer_info_vars;
+  Array<BufferInfo> buffer_info_constants;
+  for (const auto& kv : this->pass_data_.buffer_info_map_) {
+    const auto& stmt = kv.second;
+    if (stmt->IsInstance<AllocateConstNode>()) {
+      buffer_info_constants.push_back(kv.first);
+    } else {
+      buffer_info_vars.push_back(kv.first);
+    }
+  }
+  ICHECK(pass_data_.buffer_info_map_.size() ==
+         buffer_info_vars.size() + buffer_info_constants.size())
+      << "missing value";
+
+  Map<ObjectRef, ObjectRef> srch;
+  // Then intersect constants with each other, as all constants should exist at the same time:
+  for (const auto& buf : buffer_info_constants) {
+    srch.Set(buf, buf);
+    Array<ObjectRef> conflicts;
+    std::copy_if(buffer_info_constants.begin(), buffer_info_constants.end(),
+                 std::back_inserter(conflicts), [buf](const auto& b) { return b != buf; });
+    buf->conflicts.Assign(conflicts.begin(), conflicts.end());
+  }
+
+  // And third, remove all conflicts between constants and vars:
+  for (const auto& buf : buffer_info_vars) {
+    Array<ObjectRef> conflicts;
+    std::copy_if(buf->conflicts.begin(), buf->conflicts.end(), std::back_inserter(conflicts),
+                 [&srch](const auto& c) { return srch.end() == srch.find(c); });
+    buf->conflicts.Assign(conflicts.begin(), conflicts.end());
+  }
+  return BufferInfoAnalysis(this->pass_data_.buffer_info_map_, max_open_set_size);
+}
+
+tir::usmp::BufferInfoAnalysis ExtractBufferInfo(const relax::Function& main_func,
+                                                const IRModule& mod) {
+  return BufferInfoExtractor(mod)(main_func);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.extract_buffer_info")
+    .set_body_typed([](relax::Function main_func, IRModule mod) {
+      return (ExtractBufferInfo(main_func, mod));
+    });
+
+}  // namespace tvm
diff --git a/src/relax/usmp/transform/assign_pool_info.cc b/src/relax/usmp/transform/assign_pool_info.cc
new file mode 100644
index 0000000000..36cbc6e11b
--- /dev/null
+++ b/src/relax/usmp/transform/assign_pool_info.cc
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "tvm/relax/attrs/memory.h"
+#include "tvm/relax/expr_functor.h"
+
+namespace tvm {
+
+/*! \brief Assign PoolInfo objects to allocate nodes that do not have one.
+ * Schedulers have the opportunity to assign PoolInfo objects to
+ * allocate nodes. However, each allocate node is expected to have
+ * at least one PoolInfo object assigned to it.
+ * If that is not the case, this pass will assign to such nodes all the
+ * PoolInfo objects that the target could access.*/
+
+namespace tir {
+namespace usmp {
+
+class TIRPoolInfoAssigner : public StmtExprMutator {
+ public:
+  explicit TIRPoolInfoAssigner(PrimFunc func,
+                               const Map<String, Array<PoolInfo>>& target_pool_infos,
+                               const Map<String, Array<PoolInfo>>& target_const_pool_infos)
+      : func_(std::move(func)),
+        target_pool_infos_(target_pool_infos),
+        target_const_pool_infos_(target_const_pool_infos) {}
+
+  Stmt operator()();
+
+ private:
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  Stmt VisitStmt_(const AllocateConstNode* op) override;
+
+  PrimFunc func_;
+  Map<String, Array<PoolInfo>> target_pool_infos_;
+  Map<String, Array<PoolInfo>> target_const_pool_infos_;
+};
+
+Stmt TIRPoolInfoAssigner::operator()() { return this->VisitStmt(func_->body); }
+
+Stmt TIRPoolInfoAssigner::VisitStmt_(const AllocateNode* op) {
+  Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget);
+  ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
+  Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations);
+  if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) {
+    ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0)
+        << "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_);
+    annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]);
+  }
+  Stmt body = VisitStmt(op->body);
+  auto allocate =
+      Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body, annotations);
+  return std::move(allocate);
+}
+
+Stmt TIRPoolInfoAssigner::VisitStmt_(const AllocateConstNode* op) {
+  if (!target_const_pool_infos_.size()) {
+    return StmtExprMutator::VisitStmt_(op);
+  }
+  Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget);
+  ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
+  Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations);
+  if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) {
+    annotations.Set(kPoolCandidatesAllocateAttr, target_const_pool_infos_[tgt.value()->str()]);
+    annotations.Set(kTargetPoolReadOnlyAccess, Integer(1));
+  }
+  Stmt body = VisitStmt(op->body);
+  auto allocate_const =
+      AllocateConst(op->buffer_var, op->dtype, op->extents, op->data, body, annotations);
+  return std::move(allocate_const);
+}
+
+}  // namespace usmp
+}  // namespace tir
+
+namespace relax {
+namespace usmp {
+
+class RelaxPoolInfoAssigner : public ExprMutator {
+ public:
+  explicit RelaxPoolInfoAssigner(Function func,
+                                 const Map<String, Array<PoolInfo>>& target_pool_infos,
+                                 const Map<String, Array<PoolInfo>>& target_const_pool_infos)
+      : func_(std::move(func)),
+        target_pool_infos_(target_pool_infos),
+        target_const_pool_infos_(target_const_pool_infos) {}
+
+  Expr operator()();
+
+ private:
+  Expr VisitExpr_(const CallNode* op) override;
+
+  Function func_;
+  Map<String, Array<PoolInfo>> target_pool_infos_;
+  Map<String, Array<PoolInfo>> target_const_pool_infos_;
+};
+
+Expr RelaxPoolInfoAssigner::operator()() { return this->VisitExpr(func_->body); }
+
+Expr RelaxPoolInfoAssigner::VisitExpr_(const CallNode* call) {
+  Expr expr = VisitExprPostOrder_(call);
+  call = expr.as<CallNode>();
+
+  static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+  if (call->op != alloc_tensor_op) {
+    return GetRef<Call>(call);
+  }
+  Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget);
+  ICHECK(tgt) << "The following Func does not have a target attr: \n" << func_;
+  auto alloc_attrs = call->attrs.as<AllocTensorAttrs>();
+  ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs";
+  if (alloc_attrs->candidate_memory_pools.size() > 0) {
+    return GetRef<Call>(call);
+  }
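+  // Only alloc_tensor calls that carry no user-assigned candidate pools reach
+  // this point; they receive every pool registered for the function's target.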
ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0) + << "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_); + auto alloc_tensor_attr = make_object(); + alloc_tensor_attr->dtype = alloc_attrs->dtype; + alloc_tensor_attr->runtime_device_index = alloc_attrs->runtime_device_index; + alloc_tensor_attr->candidate_memory_pools = target_pool_infos_[tgt.value()->str()]; + auto allocate_call = + Call(call->op, call->args, Attrs(alloc_tensor_attr), call->type_args, call->span); + return std::move(allocate_call); +} + +} // namespace usmp +} // namespace relax + +class PoolInfoAssigner { + public: + explicit PoolInfoAssigner(const IRModule& module) { + auto main_func = + Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); + ICHECK(main_func.defined()) << "main function is not in the module"; + Optional target_host = main_func->GetAttr(tvm::attr::kTarget); + ICHECK(target_host) << "main function does not have a target attr"; + WorkspaceMemoryPools workspace_pools = + module->GetAttr(tvm::attr::kWorkspaceMemoryPools) + .value_or(WorkspaceMemoryPools({CreateDefaultWorkspaceMemoryPool(module)})); + // make default ConstantPoolInfo if no constant and no workspace pool infos supplied + ConstantMemoryPools constant_pools = + module->GetAttr(tvm::attr::kConstantMemoryPools) + .value_or( + module->GetAttr(tvm::attr::kWorkspaceMemoryPools).defined() + ? ConstantMemoryPools() + : ConstantMemoryPools({CreateDefaultConstantMemoryPool(module)})); + auto to_map = [](auto pool_infos) { + Map> pool_map; + for (const PoolInfo& pool_info : pool_infos) { + for (const auto& tgt : pool_info->targets) { + if (pool_map.find(tgt->str()) == pool_map.end()) { + pool_map.Set(tgt->str(), Array()); + } + Array pool_info_arr = pool_map[tgt->str()]; + pool_info_arr.push_back(pool_info); + pool_map.Set(tgt->str(), pool_info_arr); + } + } + return pool_map; + }; + + target_pool_infos_ = to_map(workspace_pools->pools); + if (constant_pools.defined()) { + target_const_pool_infos_ = to_map(constant_pools->pools); + } + mod_ = module->ShallowCopy(); + } + + IRModule operator()(); + + private: + IRModule mod_; + Map> target_pool_infos_; + Map> target_const_pool_infos_; + WorkspacePoolInfo CreateDefaultWorkspaceMemoryPool(const IRModule& module); + ConstantPoolInfo CreateDefaultConstantMemoryPool(const IRModule& module) { + auto p = CreateDefaultWorkspaceMemoryPool(module); + return ConstantPoolInfo( + "global_const_workspace", {p->targets}, {}, + PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth, + kUnknownWriteBandwidth, 0, 0, {p->target_burst_bytes}, Bool(true))); + } +}; + +WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) { + VLOG(1) << "Creating default memory pool for:" << std::endl << PrettyPrint(module); + Map target_access; + auto main_func = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); + Target target_host = main_func->GetAttr(tvm::attr::kTarget).value(); + for (const auto& kv : module->functions) { + BaseFunc func = kv.second; + Optional target = func->GetAttr(tvm::attr::kTarget); + target_access.Set(target.value_or(target_host), kTargetPoolReadWriteAccess); + } + Array targets; + for (const auto& kv : target_access) { + bool exist = false; + // Exclude targets with the same string representation + for (const auto& t : targets) { + if (t->str() == kv.first->str()) { + exist = true; + } + } + if (!exist) { + targets.push_back(kv.first); + } + } + return 
WorkspacePoolInfo( + "global_workspace", targets, + PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth, + kUnknownWriteBandwidth, 0, 0, {{target_host, 1}}, Bool(true))); +} + +IRModule PoolInfoAssigner::operator()() { + for (const auto& kv : mod_->functions) { + GlobalVar gv = kv.first; + if (kv.second->IsInstance()) { + using RelaxPoolInfoAssigner = relax::usmp::RelaxPoolInfoAssigner; + using Function = relax::Function; + auto func = runtime::Downcast(kv.second); + RelaxPoolInfoAssigner relax_pool_info_assigner = + RelaxPoolInfoAssigner(func, target_pool_infos_, target_const_pool_infos_); + relax::Expr body = relax_pool_info_assigner(); + Function new_relax_func = + Function(func->params, body, func->ret_type, func->ret_shape, func->attrs, func->span); + mod_->Update(gv, new_relax_func); + } else if (kv.second->IsInstance()) { + using TIRPoolInfoAssigner = tir::usmp::TIRPoolInfoAssigner; + using PrimFunc = tir::PrimFunc; + auto func = Downcast(kv.second); + TIRPoolInfoAssigner tir_pool_info_assigner = + TIRPoolInfoAssigner(func, target_pool_infos_, target_const_pool_infos_); + tir::Stmt body = tir_pool_info_assigner(); + PrimFunc new_prim_func = PrimFunc(func->params, body, func->ret_type, func->buffer_map, + func->preflattened_buffer_map, func->attrs); + mod_->Update(gv, new_prim_func); + } + } + return mod_; +} + +namespace transform { + +tvm::transform::Pass AssignPoolInfo() { + auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { + return PoolInfoAssigner(m)(); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "relax.usmp.AssignPoolInfo", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AssignPoolInfo").set_body_typed(AssignPoolInfo); + +} // namespace transform +} // namespace tvm diff --git a/src/relax/utils.cc b/src/relax/utils.cc new file mode 100644 index 0000000000..75a882de45 --- /dev/null +++ b/src/relax/utils.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Helper to implement bind params.*/ +class ExprBinder : public ExprMutator { + public: + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} + + Expr VisitExpr_(const VarNode* op) final { + auto id = GetRef(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return ExprMutator::VisitExpr_(op); + } + } + + private: + const tvm::Map& args_map_; +}; + +/*! 
+ * \brief Bind params on expr + * \param expr The expr where to bind params + * \param args_map The map from param var to the expr it binds to + * \return The result expr after bind params + */ +Expr Bind(const Expr& expr, const tvm::Map& args_map) { + if (const FunctionNode* func = expr.as()) { + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); + Array new_params; + for (size_t i = 0; i < func->params.size(); ++i) { + if (!args_map.count(func->params[i])) { + new_params.push_back(func->params[i]); + } + } + if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { + return expr; + } + // The checked_type_ of the new function is deduced from the function body + // TODO(@relax-team): Should infer the shape from the body as well + return Function(new_params, new_body, Type(), RuntimeDepShape(), func->attrs); + } else { + return ExprBinder(args_map).VisitExpr(expr); + } +} + +bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { + const DynTensorTypeNode* tt = ty.as(); + if (!tt) { + return false; + } + bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void()); + bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1); + return correct_dtype && correct_rank; +} + +} // namespace relax +} // namespace tvm diff --git a/src/relay/analysis/graph_partitioner.cc b/src/relay/analysis/graph_partitioner.cc new file mode 100644 index 0000000000..861fd58d9e --- /dev/null +++ b/src/relay/analysis/graph_partitioner.cc @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "./graph_partitioner.h" + +#include + +namespace tvm { +namespace relay { + +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); + } + return tree; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs, + OpPatternKind* edge_pattern) { + while (lhs != rhs) { + if (lhs == nullptr) return nullptr; + if (rhs == nullptr) return nullptr; + if (lhs->depth < rhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + rhs = rhs->parent; + } else if (rhs->depth < lhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + lhs = lhs->parent; + } else { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + lhs = lhs->parent; + rhs = rhs->parent; + } + } + return lhs; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor( + const LinkedList& input_nodes, OpPatternKind* edge_pattern) { + auto link = input_nodes.head; + if (link == nullptr) { + return nullptr; + } + auto get_node = [&](const IndexedForwardGraph::Edge& edge) { + size_t oindex = edge.node->index; + ICHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; + ICHECK(onode != nullptr); + return onode; + }; + Node* parent = get_node(link->value); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + link = link->next; + for (; link != nullptr; link = link->next) { + parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + } + return parent; +} + +DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena, + IndexedForwardGraph::Node* gnode) { + Node* tnode = arena->make(); + tnode->gnode = gnode; + if (gnode->extern_ref) { + tnode->depth = 1; + tnode->parent = nullptr; + tnode->pattern = kOpaque; + } else { + // find the LCAs of all outputs. + OpPatternKind pattern = kElemWise; + Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); + tnode->depth = parent ? parent->depth + 1 : 1; + tnode->parent = parent; + tnode->pattern = pattern; + } + return tnode; +} + +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { + this->InitGroups(graph); + if (opt_level_ == 0) return std::move(groups_); + // get post dominator tree + auto post_dom_tree = DominatorTree::PostDom(arena_, graph); + // run fusion algorithm. + for (int phase = 0; phase < 3; ++phase) { + this->RunFuse(graph, post_dom_tree, phase); + } + return std::move(groups_); +} + +GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() { + // fast path + if (this->parent == nullptr) return this; + // slow path with path compression. 
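+  // e.g. a chain a -> b -> c -> root is flattened so that a, b and c all end
+  // up pointing directly at root, keeping subsequent FindRoot() calls close
+  // to O(1).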
+ Group* root = this; + while (root->parent != nullptr) { + root = root->parent; + } + for (Group* p = this; p != root;) { + Group* parent = p->parent; + p->parent = root; + p = parent; + } + return root; +} + +template +bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + if (visited_.count(src)) return true; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + gnode = gnode->FindRoot(); + if (!fcond(gnode->pattern, src == sink)) return false; + if (src == sink) return true; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +template +bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + ICHECK(!src->extern_ref); + visited_.clear(); + ICHECK(src != sink); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) { + LOG(FATAL) << "Cannot merge two complex group together"; + } + if (lhs > rhs) return lhs; + return rhs; +} + +void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { + child = child->FindRoot(); + parent = parent->FindRoot(); + if (child == parent) return; + // update the number of nodes of the parent group + parent->num_nodes += child->num_nodes; + child->parent = parent; + // update anchor ref and pattern + if (child->anchor_ref != nullptr) { + ICHECK(parent->anchor_ref == nullptr); + parent->anchor_ref = child->anchor_ref; + parent->pattern = CombinePattern(child->pattern, parent->pattern); + } +} + +void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + Group* target) { + if (src == sink) return; + if (visited_.count(src)) return; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + // merge the current group to the parent if possible. 
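+  // Every node on a src-to-sink path is post-dominated by the sink, so folding
+  // its group into the sink's target group keeps each fused subgraph closed
+  // under all paths between src and its post-dominator.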
+ MergeFromTo(gnode, target); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + CommitFuse_(link->value.node, sink, target); + } +} + +void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + Group* target = groups_[sink->index]; + visited_.clear(); + ICHECK(src != sink); + CommitFuse_(src, sink, target); +} + +size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink) { + if (src == sink || visited_.count(src)) return 0; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + auto sum = gnode->num_nodes; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + sum += CountNodesUptoSink_(link->value.node, sink); + } + return sum; +} + +size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent) { + Group* target = groups_[dom_parent->index]; + visited_.clear(); + ICHECK(child != dom_parent); + return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); +} + +void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { + groups_.resize(graph.post_dfs_order.size()); + for (size_t nid = 0; nid < groups_.size(); ++nid) { + const auto* graph_node = graph.post_dfs_order[nid]; + auto* group_node = arena_->make(); + group_node->pattern = graph_node->pattern; + group_node->root_ref = graph_node->ref; + // set anchor ref if necessary. + if (group_node->pattern == relay::kOutEWiseFusable) { + group_node->anchor_ref = graph_node->ref; + } + groups_[nid] = group_node; + } +} + +void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // + const DominatorTree& post_dom_tree, // + int phase) { + for (size_t nid = 0; nid < groups_.size(); ++nid) { + // the group of current node has been specified already. + auto* graph_node = graph.post_dfs_order[nid]; + auto* dom_node = post_dom_tree.nodes[nid]; + Group* group_node = groups_[nid]; + ICHECK(group_node != nullptr); + // no actions for opaque nodes + if (group_node->pattern == kOpaque) continue; + // no actions needed if the current node have no dominator + if (dom_node->parent == nullptr) continue; + ICHECK(!graph_node->extern_ref); + size_t dom_parent_gindex = dom_node->parent->gnode->index; + + // refuse the fusion if too many ops are going to be fused together + if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) + continue; + + if (phase == 2) { + // Fuse injective ops into intermediate tuples, if any + if (group_node->pattern > relay::kInjective) continue; + Group* dom_parent_group = groups_[dom_parent_gindex]; + Group* dom_root_group = dom_parent_group->FindRoot(); + // If dom node group has a tuple as its root, we do not fuse tuple fields into it + if (dom_root_group->pattern == relay::kTuple) continue; + if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) { + // Now we know the tuple has been fused into subsequent injective ops + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + // dom_root_group can also be tuple, as in inception layers + // CheckPath is needed to avoid fusing two intermediate tuples + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + continue; + } + + // Skip if current node is already fused to the parent. 
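+    // Two nodes belong to the same group exactly when their FindRoot()
+    // results coincide, so this containment check is cheap after path
+    // compression.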
+ if (groups_[dom_parent_gindex] != nullptr && + group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { + continue; + } + // Do not fuse into tuple for now + if (groups_[dom_parent_gindex]->pattern == kTuple) continue; + // Try to fuse current node to its post-dominator. + if (group_node->pattern == kOutEWiseFusable) { + if (phase != 0) continue; + // Path for OutEWiseFusable: conv2d + // Check if the dominator relation is elemwise. + if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { + ICHECK(dom_node->parent->gnode != nullptr); + // The fuse can be executed if all the intermediate ops are still broadcast. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern <= kBroadcast) { + // Pre-condition: can only be fused to parent which is injective or reduction. + if (dom_node->parent != nullptr && + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { + // Check if all the intermediate ops are still broadcast. + // The final terminal node can already be fused to a OutEWiseFusable group. + auto fcond = [](OpPatternKind kind, bool is_sink) { + if (!is_sink) { + // Elemwise, broadcast, and injective ops on the parallel branches + // are allowed be fused to the elemwise/broadcast anchor. + return kind <= kInjective; + } else { + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || + kind == kOutEWiseFusable); + } + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { + // defer injective fusion to second phase. + // so conv2d always finishes fusing. + if (phase != 1) continue; + // Check if all path are injective. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } else { + // do nothing. + ICHECK(group_node->pattern == kCommReduce); + } + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/analysis/graph_partitioner.h b/src/relay/analysis/graph_partitioner.h new file mode 100644 index 0000000000..17835633b5 --- /dev/null +++ b/src/relay/analysis/graph_partitioner.h @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/graph_partitioner.h + * \brief The helper function for op fusion. 
+ */ + +#include + +#include "../../support/arena.h" + +namespace tvm { +namespace relay { + +using support::LinkedList; +using support::LinkNode; + +/*! + * \brief Indexed data flow graph in forward direction. + * This is a temporary data structure used for operator fusion analysis. + * + * This data structure only captures the dataflow fragment and + * could ignore blocks like let by simply ordering each dataflow block + * and mark the output node as extern_ref; + */ +class IndexedForwardGraph { + public: + struct Node; + /*! + * The forward edge in the dataflow graph. + */ + struct Edge { + /*! \brief The corresponding node */ + Node* node{nullptr}; + /*! \brief The respective pattern of this op */ + OpPatternKind pattern{kOpaque}; + }; + /*! \brief A node in the graph. */ + struct Node { + /*! \brief weak reference to the corresponding edge. */ + const tvm::Object* ref{nullptr}; + /*! \brief The index of the node in topological order. */ + size_t index{0}; + /*! \brief Whether this node is referenced by external source */ + bool extern_ref{false}; + /*! \brief The general pattern in the node */ + OpPatternKind pattern{kOpaque}; + /*! \brief The outputs of the node. */ + LinkedList outputs; + }; + /*! \brief The node map that maps node to graph */ + std::unordered_map node_map; + /*! \brief All the nodes in post DFS order */ + std::vector post_dfs_order; + + /*! \brief Dump the graph into string. */ + void DebugDump() { + std::ostringstream os; + for (size_t i = 0; i < post_dfs_order.size(); ++i) { + Node* node = post_dfs_order[i]; + ObjectRef object = GetRef(node->ref); + os << "node[" << i << "], " << object->GetTypeKey() << " " << node->extern_ref << " " + << node->pattern << " outputs=["; + for (auto* link = node->outputs.head; link != nullptr; link = link->next) { + os << link->value.node->index << " (" << link->value.pattern << "), "; + } + os << "]\n"; + } + LOG(INFO) << "\n" << os.str(); + } +}; + +/*! + * \brief Dominator tree that represent domination or + * post domination relation of the node. + */ +class DominatorTree { + public: + /*! + * \brief A node in the dominator tree. + */ + struct Node { + /*! \brief The node in the tree */ + IndexedForwardGraph::Node* gnode{nullptr}; + /*! \brief parent of the tree */ + Node* parent{nullptr}; + /*! \brief current depth*/ + int depth{0}; + /*! \brief aggregated pattern to parent */ + OpPatternKind pattern{kOpaque}; + }; + // index -> node. + std::vector nodes; + /*! + * \brief compute a post dominator relation for a given dataflow graph. + * \param arena The arena used for node allocation. + * \param graph The graph to be analyzed. + * \return The dominator tree of the graph. + * \note This algorithm makes use of the fact that graph is DAG, + * and runs a single pass algorithm via LCA (Least Common Ancestor) + */ + static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); + + private: + // Combine pattern together. + inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > rhs) return lhs; + return rhs; + } + /*! + * \brief Find the least common ancestor of the two nodes. + * \param lhs The left node. + * \param rhs The right node. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of the two. + */ + static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern); + /*! + * \brief Find the least common ancestor of a list of nodes. + * \param nodes the nodes. 
+ * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of all nodes. + */ + Node* LeastCommonAncestor(const LinkedList& input_nodes, + OpPatternKind* edge_pattern); + + /*! + * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. + * \param arena The Arena. + * \param gnode An IndexedForwardGraph Node. + * \return The DominatorTree Node. + */ + Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode); +}; + +class GraphPartitioner { + public: + explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) + : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} + /*! + * \brief Group as a union find data structure. + */ + struct Group { + /*! \brief The parent in the union find data structure. */ + Group* parent{nullptr}; + /*! \brief The pattern of the group */ + OpPatternKind pattern; + /*! \brief reference to the root node. */ + const tvm::Object* root_ref{nullptr}; + /*! + * \brief Reference to the anchor node, + * this field is not nullptr only if pattern is kOutEWiseFusable. + */ + const tvm::Object* anchor_ref{nullptr}; + /*! + * \brief The number of nodes belonging to this group + */ + uint32_t num_nodes{1}; + + /*! + * \brief Find the group root, perform path compression + * \return The root type node. + */ + Group* FindRoot(); + }; + /*! + * \brief Partition a graph. + * \return group assignments of each node. + */ + std::vector Partition(const IndexedForwardGraph& graph); + + private: + /*! \brief The internal arena for temporary space. */ + support::Arena* arena_; + /*! \brief optimization level for fuse operation. */ + int opt_level_; + /*! \brief The maximum number of operations in one fused function */ + size_t max_fuse_depth_; + /*! \brief The internal groups. */ + std::vector groups_; + /*! \brief internal field used for deduplication */ + std::unordered_set visited_; + // Internal implementation of CheckPath + template + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Check all the node and edge pattern + * between src and sink satisfies fcond. + * + * src is not checked. + * + * \param src The source node. + * \param sink The termination node. + * \param fcond The condition to be checked. + * \tparam F the condition function, with signature + * \note sink must be a post-dominator of src. + */ + template + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Merge the child group to the parent. + * \param child The child group. + * \param parent The parent group. + */ + void MergeFromTo(Group* child, Group* parent); + + // Internal implementation of CommitFuse + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target); + + /*! + * \brief Commit fusion operation. + * \param src The source node. + * \param sink The termination node. + * \note sink must be a post-dominator of src. + */ + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + // Count the number of nodes in a fused subgraph if child is additionally fused. + // dom_parent is already known to be a part of the subgraph. + // For a diamond structure, there can be multiple paths connecting child and dom_parent. 
+ // All intermediate nodes between child and dom_parent are taken into account. + // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() + // is important for correct calculation. + size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent); + + // Initialize the groups. + void InitGroups(const IndexedForwardGraph& graph); + + // execute the fusion algorithm. + void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase); +}; + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 51bcab527d..04aad09fd3 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -416,7 +416,7 @@ Optional DefaultTIRConverterImpl(const Array& args, return NullOpt; } } - PrimFunc func = te::CreatePrimFuncWithConstants(args, constants); + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, {}); bool dynamic_loop_extent = false; tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { @@ -443,6 +443,13 @@ TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern") return DefaultTIRConverterImpl(args, constants, true); }); +TVM_REGISTER_GLOBAL("relay.backend.GetPassPrefixSeq") + .set_body_typed([](bool is_homogeneous, bool is_vm) { + auto pass_seqs = GetPassPrefix(is_homogeneous, is_vm); + transform::Sequential seq(pass_seqs); + return seq; + }); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 5f7b8747a7..b495687233 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -39,6 +39,8 @@ Id::Id(String name_hint) { data_ = std::move(n); } +TVM_REGISTER_GLOBAL("relay.ir.Id").set_body_typed([](String name_hint) { return Id(name_hint); }); + TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { rn->span = sp; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 5c85b3b29d..3c53a3a1f2 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -22,6 +22,7 @@ * \brief The expression AST nodes of Relay. 
*/ #include +#include #include #include @@ -367,8 +368,8 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") - .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, true_branch, false_branch); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { + return If(cond, true_branch, false_branch, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index fc1f3a1507..dd31a1f736 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -154,8 +154,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, String name, tvm::Array required) { - PassInfo pass_info = PassInfo(opt_level, name, required); + int opt_level, String name, tvm::Array required, bool traceable) { + PassInfo pass_info = PassInfo(opt_level, name, required, traceable); return FunctionPass(pass_func, pass_info); } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index dac5dc69ea..433ad1547b 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -25,13 +25,13 @@ * Fuse necessary ops into a single one. */ #include -#include #include #include #include #include #include "../../support/arena.h" +#include "../analysis/graph_partitioner.h" #include "../op/annotation/annotation.h" #include "./pass_utils.h" #include "./pattern_utils.h" @@ -86,78 +86,16 @@ constexpr uint32_t kMaxFusedOps = 256; static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool); - -/*! - * \brief Indexed data flow graph in forward direction. - * This is a temporary data structure used for operator fusion analysis. - * - * This data structure only captures the dataflow fragment and - * could ignore blocks like let by simply ordering each dataflow block - * and mark the output node as extern_ref; - */ -class IndexedForwardGraph { - public: - struct Node; - /*! - * The forward edge in the dataflow graph. - */ - struct Edge { - /*! \brief The corresponding node */ - Node* node{nullptr}; - /*! \brief The respective pattern of this op */ - OpPatternKind pattern{kOpaque}; - }; - /*! \brief A node in the graph. */ - struct Node { - /*! \brief weak reference to the corresponding edge. */ - const tvm::Object* ref{nullptr}; - /*! \brief The index of the node in topological order. */ - size_t index{0}; - /*! \brief Whether this node is referenced by external source */ - bool extern_ref{false}; - /*! \brief The general pattern in the node */ - OpPatternKind pattern{kOpaque}; - /*! \brief The outputs of the node. */ - LinkedList outputs; - }; - /*! \brief The node map that maps node to graph */ - std::unordered_map node_map; - /*! \brief All the nodes in post DFS order */ - std::vector post_dfs_order; - - /*! \brief Dump the graph into string. */ - void DebugDump() { - std::ostringstream os; - for (size_t i = 0; i < post_dfs_order.size(); ++i) { - Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; - for (auto* link = node->outputs.head; link != nullptr; link = link->next) { - os << link->value.node->index << ", "; - } - os << "]\n"; - } - LOG(INFO) << os.str(); - } - /*! 
- * \brief create a indexed forward graph. - * \param arena The arena used for data allocation. - * \param body The body of the expression to create a graph. - */ - static IndexedForwardGraph Create(support::Arena* arena, const Expr& body); - - private: - class Creator; -}; // Creator of post dominator tree of the dataflow -class IndexedForwardGraph::Creator : private ExprVisitor { +class IndexedForwardGraphCreator : private ExprVisitor { public: - explicit Creator(support::Arena* arena) : arena_(arena) {} + explicit IndexedForwardGraphCreator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); this->VisitExpr(body); + // graph_.DebugDump(); return std::move(graph_); } @@ -213,7 +151,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. bool is_simple_const = @@ -230,7 +168,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { ICHECK(graph_.node_map.count(call)); - Node* node = graph_.node_map.at(call); + IndexedForwardGraph::Node* node = graph_.node_map.at(call); static auto fpattern = Op::GetAttrMap("TOpPattern"); // Now we set the pattern of this call. // @@ -274,7 +212,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const TupleNode* op) final { ICHECK(graph_.node_map.count(op)); - Node* tuple_node = graph_.node_map.at(op); + IndexedForwardGraph::Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { @@ -306,7 +244,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->Update(op->tuple, nullptr, kOpaque); } else { ICHECK(graph_.node_map.count(op)); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); node->pattern = kInjective; this->Update(op->tuple, node, kInjective); } @@ -372,460 +310,16 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { - return Creator(arena).Prepare(body); -} - -/*! - * \brief Dominator tree that represent domination or - * post domination relation of the node. - */ -class DominatorTree { - public: - /*! - * \brief A node in the dominator tree. - */ - struct Node { - /*! \brief The node in the tree */ - IndexedForwardGraph::Node* gnode{nullptr}; - /*! \brief parent of the tree */ - Node* parent{nullptr}; - /*! \brief current depth*/ - int depth{0}; - /*! \brief aggregated pattern to parent */ - OpPatternKind pattern{kOpaque}; - }; - // index -> node. - std::vector nodes; - /*! - * \brief compute a post dominator relation for a given dataflow graph. - * \param arena The arena used for node allocation. - * \param graph The graph to be analyzed. - * \return The dominator tree of the graph. - * \note This algorithm makes use of the fact that graph is DAG, - * and runs a single pass algorithm via LCA (Least Common Ancestor) - */ - static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); - - private: - // Combine pattern together. 
- static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Find the least common ancestor of the two nodes. - * \param lhs The left node. - * \param rhs The right node. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of the two. - */ - static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { - while (lhs != rhs) { - if (lhs == nullptr) return nullptr; - if (rhs == nullptr) return nullptr; - if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - rhs = rhs->parent; - } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - lhs = lhs->parent; - } else { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - lhs = lhs->parent; - rhs = rhs->parent; - } - } - return lhs; - } - /*! - * \brief Find the least common ancestor of a list of nodes. - * \param nodes the nodes. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of all nodes. - */ - Node* LeastCommonAncestor(const LinkedList& input_nodes, - OpPatternKind* edge_pattern) { - auto link = input_nodes.head; - if (link == nullptr) { - return nullptr; - } - auto get_node = [&](const IndexedForwardGraph::Edge& edge) { - size_t oindex = edge.node->index; - ICHECK_LT(oindex, nodes.size()); - Node* onode = nodes[oindex]; - ICHECK(onode != nullptr); - return onode; - }; - Node* parent = get_node(link->value); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - link = link->next; - for (; link != nullptr; link = link->next) { - parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - } - return parent; - } - /*! - * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. - * \param arena The Arena. - * \param gnode An IndexedForwardGraph Node. - * \return The DominatorTree Node. - */ - Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode) { - Node* tnode = arena->make(); - tnode->gnode = gnode; - if (gnode->extern_ref) { - tnode->depth = 1; - tnode->parent = nullptr; - tnode->pattern = kOpaque; - } else { - // find the LCAs of all outputs. - OpPatternKind pattern = kElemWise; - Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); - tnode->depth = parent ? parent->depth + 1 : 1; - tnode->parent = parent; - tnode->pattern = pattern; - } - return tnode; - } -}; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { - DominatorTree tree; - tree.nodes.resize(graph.post_dfs_order.size(), nullptr); - // reverse topo order - for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { - size_t index = i - 1; - tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); - } - return tree; -} - /*! * \brief A partition of the graph marked by union find data structure. */ -class GraphPartitioner { - public: - explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) - : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} - /*! - * \brief Group as a union find data structure. - */ - struct Group { - /*! \brief The parent in the union find data structure. */ - Group* parent{nullptr}; - /*! 
\brief The pattern of the group */ - OpPatternKind pattern; - /*! \brief reference to the root node. */ - const tvm::Object* root_ref{nullptr}; - /*! - * \brief Reference to the anchor node, - * this field is not nullptr only if pattern is kOutEWiseFusable. - */ - const tvm::Object* anchor_ref{nullptr}; - /*! - * \brief Find the group root, perform path compression - * \return The root type node. - */ - Group* FindRoot() { - // fast path - if (this->parent == nullptr) return this; - // slow path with path compression. - Group* root = this; - while (root->parent != nullptr) { - root = root->parent; - } - for (Group* p = this; p != root;) { - Group* parent = p->parent; - p->parent = root; - p = parent; - } - return root; - } - - /*! - * \brief The number of nodes belonging to this group - */ - uint32_t num_nodes{1}; - }; - /*! - * \brief Partition a graph. - * \return group assignments of each node. - */ - std::vector Partition(const IndexedForwardGraph& graph); - - private: - /*! \brief The internal arena for temporary space. */ - support::Arena* arena_; - /*! \brief optimization level for fuse operation. */ - int opt_level_; - /*! \brief The maximum number of operations in one fused function */ - size_t max_fuse_depth_; - /*! \brief The internal groups. */ - std::vector groups_; - /*! \brief internal field used for deduplication */ - std::unordered_set visited_; - // Internal implelementation of CheckPath - template - bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - if (visited_.count(src)) return true; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - gnode = gnode->FindRoot(); - if (!fcond(gnode->pattern, src == sink)) return false; - if (src == sink) return true; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - /*! - * \brief Check all the node and edge pattern - * between src and sink satisfies fcond. - * - * src is not checked. - * - * \param src The source node. - * \param sink The termination node. - * \param fcond The condition to be checked. - * \tparam F the condition function, with signature - * \note sink must be a post-dominator of src. - */ - template - bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - ICHECK(!src->extern_ref); - visited_.clear(); - ICHECK(src != sink); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - // Combine two patterns together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > kBroadcast && rhs > kBroadcast) { - LOG(FATAL) << "Cannot merge two complex group together"; - } - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Merge the child group to the parent. - * \param child The child group. - * \param parent The parent group. 
-   */
-  void MergeFromTo(Group* child, Group* parent) {
-    child = child->FindRoot();
-    parent = parent->FindRoot();
-    if (child == parent) return;
-    // update the number of nodes of the parent group
-    parent->num_nodes += child->num_nodes;
-    child->parent = parent;
-    // update anchor ref and pattern
-    if (child->anchor_ref != nullptr) {
-      ICHECK(parent->anchor_ref == nullptr);
-      parent->anchor_ref = child->anchor_ref;
-      parent->pattern = CombinePattern(child->pattern, parent->pattern);
-    }
-  }
-  // Internal implementation of CommitFuse
-  void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) {
-    if (src == sink) return;
-    if (visited_.count(src)) return;
-    visited_.insert(src);
-    Group* gnode = groups_[src->index];
-    ICHECK(gnode != nullptr);
-    // merge the current group to the parent if possible.
-    MergeFromTo(gnode, target);
-    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
-      CommitFuse_(link->value.node, sink, target);
-    }
-  }
-  /*!
-   * \brief Commit fusion operation.
-   * \param src The source node.
-   * \param sink The termination node.
-   * \note sink must be a post-dominator of src.
-   */
-  void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
-    Group* target = groups_[sink->index];
-    visited_.clear();
-    ICHECK(src != sink);
-    CommitFuse_(src, sink, target);
-  }
-
-  size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
-    if (src == sink || visited_.count(src)) return 0;
-    visited_.insert(src);
-    Group* gnode = groups_[src->index];
-    ICHECK(gnode != nullptr);
-    auto sum = gnode->num_nodes;
-    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
-      sum += CountNodesUptoSink_(link->value.node, sink);
-    }
-    return sum;
-  }
-
-  // Count the number of nodes in a fused subgraph if child is additionally fused.
-  // dom_parent is already known to be a part of the subgraph.
-  // For a diamond structure, there can be multiple paths connecting child and dom_parent.
-  // All intermediate nodes between child and dom_parent are taken into account.
-  // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot()
-  // is important for correct calculation.
-  size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
-                                     IndexedForwardGraph::Node* dom_parent) {
-    Group* target = groups_[dom_parent->index];
-    visited_.clear();
-    ICHECK(child != dom_parent);
-    return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
-  }
-
-  // Initialize the groups.
-  void InitGroups(const IndexedForwardGraph& graph) {
-    groups_.resize(graph.post_dfs_order.size());
-    for (size_t nid = 0; nid < groups_.size(); ++nid) {
-      const auto* graph_node = graph.post_dfs_order[nid];
-      auto* group_node = arena_->make<Group>();
-      group_node->pattern = graph_node->pattern;
-      group_node->root_ref = graph_node->ref;
-      // set anchor ref if necessary.
-      if (group_node->pattern == kOutEWiseFusable) {
-        group_node->anchor_ref = graph_node->ref;
-      }
-      groups_[nid] = group_node;
-    }
-  }
-
-  // execute the fusion algorithm.
-  void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) {
-    for (size_t nid = 0; nid < groups_.size(); ++nid) {
-      // the group of the current node has been specified already.
-      auto* graph_node = graph.post_dfs_order[nid];
-      auto* dom_node = post_dom_tree.nodes[nid];
-      Group* group_node = groups_[nid];
-      ICHECK(group_node != nullptr);
-      // no actions for opaque nodes
-      if (group_node->pattern == kOpaque) continue;
-      // no actions needed if the current node has no dominator
-      if (dom_node->parent == nullptr) continue;
-      ICHECK(!graph_node->extern_ref);
-      size_t dom_parent_gindex = dom_node->parent->gnode->index;
-
-      // refuse the fusion if too many ops are going to be fused together
-      if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
-        continue;
-
-      if (phase == 2) {
-        // Fuse injective ops into intermediate tuples, if any
-        if (group_node->pattern > kInjective) continue;
-        Group* dom_parent_group = groups_[dom_parent_gindex];
-        Group* dom_root_group = dom_parent_group->FindRoot();
-        // If dom node group has a tuple as its root, we do not fuse tuple fields into it
-        if (dom_root_group->pattern == kTuple) continue;
-        if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) {
-          // Now we know the tuple has been fused into subsequent injective ops
-          auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
-          // dom_root_group can also be tuple, as in inception layers
-          // CheckPath is needed to avoid fusing two intermediate tuples
-          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
-            CommitFuse(graph_node, dom_node->parent->gnode);
-          }
-        }
-        continue;
-      }
-
-      // Skip if current node is already fused to the parent.
-      if (groups_[dom_parent_gindex] != nullptr &&
-          group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
-        continue;
-      }
-      // Do not fuse into tuple for now
-      if (groups_[dom_parent_gindex]->pattern == kTuple) continue;
-      // Try to fuse current node to its post-dominator.
-      if (group_node->pattern == kOutEWiseFusable) {
-        if (phase != 0) continue;
-        // Path for OutEWiseFusable: conv2d
-        // Check if the dominator relation is elemwise.
-        if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
-          ICHECK(dom_node->parent->gnode != nullptr);
-          // The fuse can be executed if all the intermediate ops are still broadcast.
-          auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
-          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
-            CommitFuse(graph_node, dom_node->parent->gnode);
-          }
-        }
-      } else if (group_node->pattern <= kBroadcast) {
-        // Pre-condition: can only be fused to parent which is injective or reduction.
-        if (dom_node->parent != nullptr &&
-            (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
-          // Check if all the intermediate ops are still broadcast.
-          // The final terminal node can already be fused to a OutEWiseFusable group.
-          auto fcond = [](OpPatternKind kind, bool is_sink) {
-            if (!is_sink) {
-              // Elemwise, broadcast, and injective ops on the parallel branches
-              // are allowed to be fused to the elemwise/broadcast anchor.
-              return kind <= kInjective;
-            } else {
-              return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
-                      kind == kOutEWiseFusable);
-            }
-          };
-          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
-            CommitFuse(graph_node, dom_node->parent->gnode);
-          }
-        }
-      } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
-        // defer injective fusion to second phase.
-        // so conv2d always finishes fusing.
-        if (phase != 1) continue;
-        // Check if all paths are injective.
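`RunFuse` only commits a fusion after `CheckPath` proves that every path to the post-dominator satisfies the pattern predicate, and `CommitFuse` then merges groups through the union-find `FindRoot`. A minimal standalone sketch of that union-find core with path compression (a hypothetical `Group`, not the one above); the injective-path lambda resumes immediately after it:

```cpp
#include <cassert>

// Minimal union-find with the same path-compression shape as
// GraphPartitioner::Group::FindRoot in the removed code (illustrative only).
struct Group {
  Group* parent{nullptr};
  Group* FindRoot() {
    Group* root = this;
    while (root->parent != nullptr) root = root->parent;  // find
    for (Group* p = this; p != root;) {                   // compress
      Group* next = p->parent;
      p->parent = root;
      p = next;
    }
    return root;
  }
};

int main() {
  Group a, b, c;
  b.parent = &a;
  c.parent = &b;
  assert(c.FindRoot() == &a);
  assert(c.parent == &a);  // path was compressed
}
```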
- auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } else { - // do nothing. - ICHECK(group_node->pattern == kCommReduce); - } - } - } -}; - -std::vector GraphPartitioner::Partition( - const IndexedForwardGraph& graph) { - this->InitGroups(graph); - if (opt_level_ == 0) return std::move(groups_); - // get post dominator tree - auto post_dom_tree = DominatorTree::PostDom(arena_, graph); - // run fusion algorithm. - for (int phase = 0; phase < 3; ++phase) { - this->RunFuse(graph, post_dom_tree, phase); - } - return std::move(groups_); -} class FuseMutator : private MixedModeMutator { public: - FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params) - : fuse_opt_level_(fuse_opt_level), - max_fuse_depth_(max_fuse_depth), - link_params_(link_params) {} - // Run the transform - Expr Transform(const Expr& body) { - return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_); - } - - protected: - // Run the transform - Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { + Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) { // setup the group map. - auto graph = IndexedForwardGraph::Create(&arena_, body); + auto graph = IndexedForwardGraphCreator(&arena_).Prepare(body); auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { ICHECK(graph.post_dfs_order[nid]->ref != nullptr); @@ -837,10 +331,6 @@ class FuseMutator : private MixedModeMutator { } private: - int fuse_opt_level_; - size_t max_fuse_depth_; - bool link_params_; - using MixedModeMutator::VisitExpr_; /*! \brief Temporary information from each group. */ @@ -1011,12 +501,8 @@ class FuseMutator : private MixedModeMutator { auto type = arg->checked_type(); Expr new_arg = this->Mutate(arg); if (current_group != arg_group) { - if (!link_params_ || new_arg.as() == nullptr) { - Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type); - new_args.push_back(param); - } else { - new_args.push_back(new_arg); - } + Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type); + new_args.push_back(param); } else { new_args.push_back(new_arg); } @@ -1038,9 +524,8 @@ class FuseMutator : private MixedModeMutator { } }; -Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, bool link_params, - const IRModule& module) { - return FuseMutator(fuse_opt_level, max_fuse_depth, link_params).Transform(expr); +Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, const IRModule& module) { + return FuseMutator().Transform(expr, fuse_opt_level, max_fuse_depth); } namespace transform { @@ -1048,17 +533,9 @@ namespace transform { Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - bool link_params = false; - Executor executor = - m->GetAttr(tvm::attr::kExecutor).value_or(NullValue()); - link_params = executor.defined() - ? executor->attrs.GetAttr("link-params").value_or(Bool(link_params)) - : link_params; - link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value(); int opt_level = fuse_opt_level == -1 ? 
pc->opt_level : fuse_opt_level; auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps)); - return Downcast( - FuseOps(f, opt_level, max_fuse_depth.value().IntValue(), link_params, m)); + return Downcast(FuseOps(f, opt_level, max_fuse_depth.value().IntValue(), m)); }; return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"}); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index d2eb48073f..a152bbe9c3 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -950,7 +950,7 @@ TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const E }); Pass InferType() { - auto pass_info = PassInfo(0, "InferType", {}); + auto pass_info = PassInfo(0, "InferType", {}, /* trace */ false); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { // Execute the pass function and return a new module. diff --git a/src/runtime/hexagon/android/sim/hexagon_device_sim.cc b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc new file mode 100644 index 0000000000..05559a1d1a --- /dev/null +++ b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc @@ -0,0 +1,1468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../hexagon_device.h" +#include "HexagonWrapper.h" +#include "hexagon_sim_proto.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), "Hexagon VA must be uint32"); + +template +struct unalign { + using type = struct { T value; } __attribute__((aligned(1), packed)); +}; + +template +struct uint { + using type = void; +}; + +template <> +struct uint<8> { + using type = uint64_t; +}; +template <> +struct uint<4> { + using type = uint32_t; +}; +template <> +struct uint<2> { + using type = uint16_t; +}; +template <> +struct uint<1> { + using type = uint8_t; +}; + +using string_list = std::deque; + +namespace detail { + +template +std::unique_ptr make_unique(Args... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} +template +std::unique_ptr make_unique(size_t size) { + using U = typename std::remove_extent::type; + return std::unique_ptr(new U[size]()); +} + +// An "Optional" class, originally a replacement for llvm::Optional, then an +// extension of dmlc::optional to make it compatible with C++17's std::optional. 
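Before the shim itself: the idea is to keep dmlc::optional's storage and has-value semantics but add pointer-style access so parsing code can write `str->size()`. `MiniOptional` below is a hypothetical, reduced illustration of that interface, not the `detail::Optional` defined next:

```cpp
#include <cassert>
#include <string>

// Hypothetical minimal optional-with-arrow: a value plus a has-value flag,
// with operator-> so call sites read like std::optional users.
template <typename T>
class MiniOptional {
 public:
  MiniOptional() = default;
  MiniOptional(const T& v) : val_(v), has_(true) {}  // NOLINT(*)
  explicit operator bool() const { return has_; }
  T& operator*() { return val_; }
  T* operator->() { return &val_; }

 private:
  T val_{};
  bool has_{false};
};

int main() {
  MiniOptional<std::string> s(std::string("v66"));
  assert(s && s->size() == 3);  // arrow access on the wrapped value
}
```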
+template +struct Optional : public dmlc::optional { + using dmlc::optional::optional; + using dmlc::optional::operator=; + Optional(const T& val) : dmlc::optional(val) {} // NOLINT(*) + + T* operator->() { return &this->operator*(); } + const T* operator->() const { return &this->operator*(); } +}; + +// Converter class to translate vector to char**. This relieves the +// user from memory reallocation and copying. +struct non_const_str { + non_const_str() {} + explicit non_const_str(const std::string& str) : non_const_str(std::vector{str}) {} + explicit non_const_str(const std::vector& vec) { + for (const std::string& s : vec) { + auto c = detail::make_unique(s.size() + 1); + std::strncpy(c.get(), s.c_str(), s.size() + 1); + storage_.push_back(std::move(c)); + pointers_.push_back(storage_.back().get()); + } + } + non_const_str(non_const_str&& ncs) { *this = std::move(ncs); } + non_const_str& operator=(non_const_str&& ncs) { + if (this != &ncs) { + for (auto& s : ncs.storage_) storage_.push_back(std::move(s)); + for (auto& s : storage_) pointers_.push_back(s.get()); + } + return *this; + } + size_t size() const { return pointers_.size(); } + operator char*() { + ICHECK_EQ(pointers_.size(), 1); + return pointers_[0]; + } + operator char**() { return pointers_.data(); } + + private: + std::vector pointers_; + std::vector> storage_; +}; + +using MaybeString = Optional; + +MaybeString front(const string_list& deq) { + return !deq.empty() ? MaybeString(deq.front()) : MaybeString(); +} + +MaybeString pop_front(string_list& deq) { // NOLINT(*) + if (deq.empty()) return MaybeString(); + std::string f = deq.front(); + deq.pop_front(); + return MaybeString(f); +} + +Optional to_int(const MaybeString& str) { + auto none = Optional(); + if (str.has_value()) { + try { + size_t pos; + int64_t val = std::stoll(*str, &pos, 0); + return pos == str->size() ? Optional(val) : none; + } catch (std::invalid_argument) { + } + } + return none; +} + +Optional to_uint(const MaybeString& str) { + auto none = Optional(); + if (str.has_value()) { + try { + size_t pos; + uint64_t val = std::stoull(*str, &pos, 0); + return pos == str->size() ? Optional(val) : none; + } catch (std::invalid_argument) { + } + } + return none; +} + +Optional to_float(const MaybeString& str) { + auto none = Optional(); + if (str.has_value()) { + try { + size_t pos; + float val = std::stof(*str, &pos); + return pos == str->size() ? Optional(val) : none; + } catch (std::invalid_argument) { + } + } + return none; +} + +Optional to_bool(const MaybeString& str) { + auto none = Optional(); + if (auto num = to_int(str)) { + if (*num == 0) return false; + if (*num == 1) return true; + return none; + } + if (str) { + if (*str == "true" || *str == "TRUE") return true; + if (*str == "false" || *str == "FALSE") return false; + } + return none; +} + +template +using MaybeRange = Optional>; + +template Parse(const MaybeString&)> +MaybeRange to_range(const MaybeString& str) { + auto none = MaybeRange(); + if (str && !str->empty()) { + auto n = str->find('-', 1); + if (n != std::string::npos) { + auto begin = Parse(str->substr(0, n)); + auto end = Parse(str->substr(n + 1, str->size() - n - 1)); + if (begin && end) { + return std::make_pair(*begin, *end); + } + } + } + return none; +} + +// Replacement for llvm::StringSwitch. 
+template +class StringSwitch { + public: + explicit StringSwitch(const std::string& key) : key(key) {} + operator T() const { + auto f = map.find(key); + if (f != map.end()) { + return f->second; + } + ICHECK(static_cast(def_val)) << "default value not set"; + return *def_val; + } + StringSwitch& Case(const std::string& key, T val) { + map.insert(std::make_pair(key, val)); + return *this; + } + StringSwitch& Default(T val) { + ICHECK(!static_cast(def_val)) << "default value already set"; + def_val = val; + return *this; + } + + private: + const std::string key; + std::map map; + Optional def_val; +}; + +// Replacement for llvm::sys::fs::access with AccessMode = Execute. +bool FileExists(const std::string& file) { return access(file.c_str(), X_OK) == 0; } + +// Replacement for llvm::sys::Process::FindInEnvPath. +MaybeString FindInEnvPath(const std::string& env_var, const std::string& file) { + auto none = MaybeString(); + if (file.empty() || file[0] == '/') { + return none; + } + + const char* e = getenv(env_var.c_str()); + std::string env_val = e != nullptr ? std::string(e) : std::string(); + + std::vector paths; + // Split the environment variable into individual paths. + size_t first = 0, env_size = env_val.size(); + for (size_t last = 0; last != env_size; ++last) { + if (env_val[last] == ':') { + if (last > first) { + paths.emplace_back(env_val, first, last - first); + } + first = last + 1; + } + } + if (first < env_size) { + paths.emplace_back(env_val, first, env_size - first); + } + + // Search for the file. + for (const std::string& dir : paths) { + std::string full = dir + '/' + file; + if (FileExists(full)) { + return full; + } + } + return none; +} +} // namespace detail + +class HexagonSimulator final : public tvm::runtime::hexagon::Device { + public: + explicit HexagonSimulator(bool enable_queuing); + ~HexagonSimulator() final {} + void* Alloc(unsigned size, unsigned align) final; + void Free(void* ptr) final; + void* AllocVtcm(unsigned size, unsigned align) final; + void FreeVtcm(void* ptr) final; + void CopyDeviceToDevice(void* dst, const void* src, unsigned len) final; + void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) final; + void CopyHostToDevice(void* dst, const void* host_src, unsigned len) final; + void* Load(const std::string& data, const std::string& fmt) final; + void Unload(void* mod) final; + void* Resolve(const std::string& sym) final; + void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, unsigned st_num) final; + + static std::string to_string(HEXAPI_Status status); + + private: + static HEX_VA_t p2va(const void* p); + static void* va2p(HEX_VA_t va); + + void CopyFromV(void* host_dst, HEX_VA_t src, unsigned len); + void CopyToV(HEX_VA_t dst, const void* host_src, unsigned len); + + template + void CopyNToV(HEX_VA_t dst, const void* host_src); + template + void CopyNFromV(void* host_dst, HEX_VA_t src); + + // NOLINTNEXTLINE(runtime/references) + void SendMsg(Message& m, const void* data, bool show_dbg); + + std::string arch_; + std::unique_ptr sim_; + HEX_VA_t dispatch_v_, message_buffer_v_; + bool task_queuing_; + + // Sim configuration routines. 
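The `StringSwitch` above is consumed near the end of this file to map flag arguments like "NANOSEC" onto enum values. A behavioral sketch of the same Case/Default chain using a plain map (the enum and keys here are hypothetical); the sim-configuration routine declarations that this comment introduces continue below:

```cpp
#include <cassert>
#include <map>
#include <string>

enum class Mode { Silent, Normal, Verbose };

int main() {
  // Mirrors StringSwitch<Mode>(key).Case(...).Default(Mode::Normal):
  // each Case() adds one map entry, conversion does the lookup.
  auto lookup = [](const std::string& key) {
    std::map<std::string, Mode> cases = {
        {"SILENT", Mode::Silent}, {"NORMAL", Mode::Normal}, {"VERBOSE", Mode::Verbose}};
    auto it = cases.find(key);
    return it != cases.end() ? it->second : Mode::Normal;  // Default(...)
  };
  assert(lookup("VERBOSE") == Mode::Verbose);
  assert(lookup("bogus") == Mode::Normal);
}
```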
+ bool Configure(string_list& opts); // NOLINT(*) + + bool HandleAHBBusPenalty(string_list& rest); // NOLINT(*) + bool HandleAHBBusRatio(string_list& rest); // NOLINT(*) + bool HandleAHBHighAddr(string_list& rest); // NOLINT(*) + bool HandleAHBLowAddr(string_list& rest); // NOLINT(*) + bool HandleAXI2BusPenalty(string_list& rest); // NOLINT(*) + bool HandleAXI2BusRatio(string_list& rest); // NOLINT(*) + bool HandleAXI2HighAddr(string_list& rest); // NOLINT(*) + bool HandleAXI2LowAddr(string_list& rest); // NOLINT(*) + bool HandleBuildTag(string_list& rest); // NOLINT(*) + bool HandleBusPenalty(string_list& rest); // NOLINT(*) + bool HandleBusRatio(string_list& rest); // NOLINT(*) + bool HandleBusTrace(string_list& rest); // NOLINT(*) + bool HandleBypassIdle(string_list& rest); // NOLINT(*) + bool HandleConnectionTimeout(string_list& rest); // NOLINT(*) + bool HandleCoprocTrace(string_list& rest); // NOLINT(*) + bool HandleCoreDump(string_list& rest); // NOLINT(*) + bool HandleCosimFile(string_list& rest); // NOLINT(*) + bool HandleDCacheTrace(string_list& rest); // NOLINT(*) + bool HandleDSPClock(string_list& rest); // NOLINT(*) + bool HandleETMCFGBase(string_list& rest); // NOLINT(*) + bool HandleGDBServ(string_list& rest); // NOLINT(*) + bool HandleHVXLength(string_list& rest); // NOLINT(*) + bool HandleICacheTrace(string_list& rest); // NOLINT(*) + bool HandleL2CacheTrace(string_list& rest); // NOLINT(*) + bool HandleL2CFGBase(string_list& rest); // NOLINT(*) + bool HandleL2TCMBase(string_list& rest); // NOLINT(*) + bool HandleMemFillRand(string_list& rest); // NOLINT(*) + bool HandleMemFill(string_list& rest); // NOLINT(*) + bool HandleMemTrace(string_list& rest); // NOLINT(*) + bool HandleNullPtr(string_list& rest); // NOLINT(*) + bool HandlePacketAnalyze(string_list& rest); // NOLINT(*) + bool HandlePCFilter(string_list& rest); // NOLINT(*) + bool HandlePCTraceMin(string_list& rest); // NOLINT(*) + bool HandlePCTraceNano(string_list& rest); // NOLINT(*) + bool HandlePCTrace(string_list& rest); // NOLINT(*) + bool HandlePMUStatsFile(string_list& rest); // NOLINT(*) + bool HandleProfile(string_list& rest); // NOLINT(*) + bool HandleProfileTimeZero(string_list& rest); // NOLINT(*) + bool HandleQuiet(string_list& rest); // NOLINT(*) + bool HandleReconnect(string_list& rest); // NOLINT(*) + bool HandleRTOS(string_list& rest); // NOLINT(*) + bool HandleSimErr(string_list& rest); // NOLINT(*) + bool HandleSimIn(string_list& rest); // NOLINT(*) + bool HandleSimOut(string_list& rest); // NOLINT(*) + bool HandleStackStart(string_list& rest); // NOLINT(*) + bool HandleStallTrace(string_list& rest); // NOLINT(*) + bool HandleStatsFile(string_list& rest); // NOLINT(*) + bool HandleSubsystemBase(string_list& rest); // NOLINT(*) + bool HandleSymFile(string_list& rest); // NOLINT(*) + bool HandleTCM(string_list& rest); // NOLINT(*) + bool HandleTCMHighAddr(string_list& rest); // NOLINT(*) + bool HandleTCMLowAddr(string_list& rest); // NOLINT(*) + bool HandleTimeFilterNS(string_list& rest); // NOLINT(*) + bool HandleTiming(string_list& rest); // NOLINT(*) + bool HandleUArchTrace(string_list& rest); // NOLINT(*) + bool HandleUseFS(string_list& rest); // NOLINT(*) + bool HandleV2PTranslation(string_list& rest); // NOLINT(*) + bool HandleVerbose(string_list& rest); // NOLINT(*) + + using MaybeUInt64 = detail::Optional; + using MaybeUIntRange = std::pair; + + bool should_parse_next(const string_list& rest); + detail::Optional to_interval(const detail::MaybeString& str); + detail::Optional 
to_timingmode(const detail::MaybeString& str); + detail::Optional to_verbosemode(const detail::MaybeString& str); + detail::Optional to_nullptr(const detail::MaybeString& str); + + MaybeUIntRange ahb_, axi2_; + detail::Optional debug_port_; + detail::non_const_str sim_dev_args_; + + using OptionHandler = bool (HexagonSimulator::*)(string_list&); + static std::map opt_map_; +}; + +decltype(HexagonSimulator::opt_map_) HexagonSimulator::opt_map_ = { + {"--ahbbuspenalty", &HexagonSimulator::HandleAHBBusPenalty}, + {"--ahbbusratio", &HexagonSimulator::HandleAHBBusRatio}, + {"--ahb:highaddr", &HexagonSimulator::HandleAHBHighAddr}, + {"--ahb:lowaddr", &HexagonSimulator::HandleAHBLowAddr}, + {"--axi2buspenalty", &HexagonSimulator::HandleAXI2BusPenalty}, + {"--axi2busratio", &HexagonSimulator::HandleAXI2BusRatio}, + {"--axi2:highaddr", &HexagonSimulator::HandleAXI2HighAddr}, + {"--axi2:lowaddr", &HexagonSimulator::HandleAXI2LowAddr}, + {"-b", &HexagonSimulator::HandleBusTrace}, + {"--build_tag", &HexagonSimulator::HandleBuildTag}, + {"--buspenalty", &HexagonSimulator::HandleBusPenalty}, + {"--busratio", &HexagonSimulator::HandleBusRatio}, + {"--bustrace", &HexagonSimulator::HandleBusTrace}, + {"--bypass_idle", &HexagonSimulator::HandleBypassIdle}, + {"--connection_timeout", &HexagonSimulator::HandleConnectionTimeout}, + {"--coproctrace", &HexagonSimulator::HandleCoprocTrace}, + {"--coredump", &HexagonSimulator::HandleCoreDump}, + {"--cosim_file", &HexagonSimulator::HandleCosimFile}, + {"--dcachetrace", &HexagonSimulator::HandleDCacheTrace}, + {"--dsp_clock", &HexagonSimulator::HandleDSPClock}, + {"-E", &HexagonSimulator::HandleSimErr}, + {"--etm_base", &HexagonSimulator::HandleETMCFGBase}, + {"--etmcfg_base", &HexagonSimulator::HandleETMCFGBase}, + {"--gdbserv", &HexagonSimulator::HandleGDBServ}, + {"-G", &HexagonSimulator::HandleGDBServ}, + {"--hvx_length", &HexagonSimulator::HandleHVXLength}, + {"--icachetrace", &HexagonSimulator::HandleICacheTrace}, + {"-I", &HexagonSimulator::HandleSimIn}, + {"--l2cachetrace", &HexagonSimulator::HandleL2CacheTrace}, + {"--l2cfg_base", &HexagonSimulator::HandleL2CFGBase}, + {"--l2tcm_base", &HexagonSimulator::HandleL2TCMBase}, + {"--memfill", &HexagonSimulator::HandleMemFill}, + {"--memfill_rand", &HexagonSimulator::HandleMemFillRand}, + {"--memtrace", &HexagonSimulator::HandleMemTrace}, + {"-m", &HexagonSimulator::HandleMemTrace}, + {"--nullptr", &HexagonSimulator::HandleNullPtr}, + {"-O", &HexagonSimulator::HandleSimOut}, + {"--packet_analyze", &HexagonSimulator::HandlePacketAnalyze}, + {"--pcfilter", &HexagonSimulator::HandlePCFilter}, + {"--pctrace", &HexagonSimulator::HandlePCTrace}, + {"--pctrace_min", &HexagonSimulator::HandlePCTraceMin}, + {"--pctrace_nano", &HexagonSimulator::HandlePCTraceNano}, + {"-p", &HexagonSimulator::HandleProfile}, + {"--pmu_statsfile", &HexagonSimulator::HandlePMUStatsFile}, + {"--profile", &HexagonSimulator::HandleProfile}, + {"--profile_timezero", &HexagonSimulator::HandleProfileTimeZero}, + {"-q", &HexagonSimulator::HandleQuiet}, + {"--quiet", &HexagonSimulator::HandleQuiet}, + {"--reconnect", &HexagonSimulator::HandleReconnect}, + {"--rtos", &HexagonSimulator::HandleRTOS}, + {"-S", &HexagonSimulator::HandleStatsFile}, + {"--sim_err", &HexagonSimulator::HandleSimErr}, + {"--sim_in", &HexagonSimulator::HandleSimIn}, + {"--sim_out", &HexagonSimulator::HandleSimOut}, + {"--stackstart", &HexagonSimulator::HandleStackStart}, + {"--stalltrace", &HexagonSimulator::HandleStallTrace}, + {"--statsfile", 
&HexagonSimulator::HandleStatsFile}, + {"--subsystem_base", &HexagonSimulator::HandleSubsystemBase}, + {"--symfile", &HexagonSimulator::HandleSymFile}, + {"--tcm", &HexagonSimulator::HandleTCM}, + {"--tcm:highaddr", &HexagonSimulator::HandleTCMHighAddr}, + {"--tcm:lowaddr", &HexagonSimulator::HandleTCMLowAddr}, + {"-t", &HexagonSimulator::HandlePCTrace}, + {"--timefilter_ns", &HexagonSimulator::HandleTimeFilterNS}, + {"--timing", &HexagonSimulator::HandleTiming}, + {"--uarchtrace", &HexagonSimulator::HandleUArchTrace}, + {"-u", &HexagonSimulator::HandlePCTraceMin}, + {"--usefs", &HexagonSimulator::HandleUseFS}, + {"--v2p_translation", &HexagonSimulator::HandleV2PTranslation}, + {"--verbose", &HexagonSimulator::HandleVerbose}, +}; + +#define CHECKED_CALL(func, ...) \ + do { \ + HEXAPI_Status s = sim_->func(__VA_ARGS__); \ + ICHECK_EQ(s, HEX_STAT_SUCCESS) \ + << "HexagonSimulator: " #func " failed with code " << HexagonSimulator::to_string(s); \ + } while (false) + +inline HEX_VA_t HexagonSimulator::p2va(const void* p) { + uintptr_t u = reinterpret_cast(p); + HEX_VA_t va = static_cast(u); + ICHECK_EQ(static_cast(va), u); + return va; +} + +inline void* HexagonSimulator::va2p(HEX_VA_t va) { + return reinterpret_cast(static_cast(va)); +} + +template +constexpr bool is_multiple_of() { + return (N / A) * A == N; +} + +std::shared_ptr CreateHexagonSimulator() { + return std::make_shared(/*enable_queuing=*/true); +} + +template +void HexagonSimulator::CopyNToV(HEX_VA_t dst, const void* host_src) { + using src_uint_t = typename unalign::type>::type; + auto* ps = reinterpret_cast(host_src); + ICHECK_EQ(sim_->WriteVirtual(dst, -1u, N, ps->value), HEX_STAT_SUCCESS); +} + +template +void HexagonSimulator::CopyNFromV(void* host_dst, HEX_VA_t src) { + typename uint::type v; + ICHECK_EQ(sim_->ReadVirtual(src, -1u, N, &v), HEX_STAT_SUCCESS); + + using dst_uint_t = typename unalign::type>::type; + auto* pd = reinterpret_cast(host_dst); + pd->value = v; +} + +void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, unsigned len) { + const uint8_t* src = static_cast(host_src); + + while (len >= 8) { + CopyNToV<8>(dst, src); + dst += 8; + src += 8; + len -= 8; + } + if (len >= 4) { + CopyNToV<4>(dst, src); + dst += 4; + src += 4; + len -= 4; + } + if (len >= 2) { + CopyNToV<2>(dst, src); + dst += 2; + src += 2; + len -= 2; + } + if (len >= 1) { + CopyNToV<1>(dst, src); + dst++; + src++; + len--; + } + ICHECK_EQ(len, 0); +} + +void HexagonSimulator::CopyFromV(void* host_dst, HEX_VA_t src, unsigned len) { + uint8_t* dst = static_cast(host_dst); + + while (len >= 8) { + CopyNFromV<8>(dst, src); + dst += 8; + src += 8; + len -= 8; + } + if (len >= 4) { + CopyNFromV<4>(dst, src); + dst += 4; + src += 4; + len -= 4; + } + if (len >= 2) { + CopyNFromV<2>(dst, src); + dst += 2; + src += 2; + len -= 2; + } + if (len >= 1) { + CopyNFromV<1>(dst, src); + dst++; + src++; + len--; + } + ICHECK_EQ(len, 0); +} + +void HexagonSimulator::SendMsg(Message& m, const void* data, bool show_dbg) { + auto run = [this](bool report_cycles) { + HEXAPI_CoreState core = HEX_CORE_RESET; + HEX_4u_t result; + HEX_8u_t cycles0, cycles1; + if (report_cycles) { + ICHECK_EQ(sim_->GetSimulatedCycleCount(&cycles0), HEX_STAT_SUCCESS); + } + + core = sim_->Run(&result); + ICHECK_EQ(core, HEX_CORE_BREAKPOINT); + if (report_cycles) { + ICHECK_EQ(sim_->GetSimulatedCycleCount(&cycles1), HEX_STAT_SUCCESS); + LOG(INFO) << "host: execution took " << (cycles1 - cycles0) << " cycles"; + } + }; + + // Send the message request. 
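`CopyToV`/`CopyFromV` above drain a buffer in 8-, 4-, 2-, then 1-byte chunks so that every simulator access is an aligned power-of-two word. A self-contained sketch of that draining loop, with a hypothetical `WriteWord` standing in for `sim_->WriteVirtual`; the request/ack handshake that this comment introduces resumes right below:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static uint8_t g_mem[64];  // fake device memory for the sketch
static void WriteWord(uint32_t dst, const void* src, unsigned n) {
  std::memcpy(g_mem + dst, src, n);
}

// Same 8/4/2/1 draining shape as CopyToV above.
void ChunkedCopy(uint32_t dst, const uint8_t* src, unsigned len,
                 void (*write_word)(uint32_t, const void*, unsigned)) {
  while (len >= 8) { write_word(dst, src, 8); dst += 8; src += 8; len -= 8; }
  if (len >= 4) { write_word(dst, src, 4); dst += 4; src += 4; len -= 4; }
  if (len >= 2) { write_word(dst, src, 2); dst += 2; src += 2; len -= 2; }
  if (len >= 1) { write_word(dst, src, 1); dst += 1; src += 1; len -= 1; }
  assert(len == 0);
}

int main() {
  const uint8_t data[11] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  ChunkedCopy(0, data, 11, &WriteWord);  // one 8-byte, one 2-byte, one 1-byte write
  assert(std::memcmp(g_mem, data, 11) == 0);
}
```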
+  Message r = {kMsgReq, m.len, 0u};
+  CopyToV(message_buffer_v_, &r, sizeof(r));
+  run(false);
+
+  // Receive the acknowledgement with the address for the payload.
+  CopyFromV(&r, message_buffer_v_, sizeof(r));
+  ICHECK_EQ(r.code, kMsgAck);
+  ICHECK_GE(r.len, m.len);
+
+  // Send the actual message.
+  m.va = r.va;
+  CopyToV(message_buffer_v_, &m, sizeof(m));
+  if (m.len > 0) CopyToV(r.va, data, m.len);
+  run(show_dbg);
+
+  // Receive the return data.
+  CopyFromV(&m, message_buffer_v_, sizeof(m));
+  ICHECK_EQ(m.code, kNone);
+}
+
+HexagonSimulator::HexagonSimulator(bool enable_queuing) {
+  task_queuing_ = enable_queuing;
+
+  // The simulator argument string is in the form:
+  //   <core version> <optional arguments>
+  // The optional arguments are separated with spaces:
+  // Ex: --hvx_length 128 --memfill 0 --timing -m output.txt
+  const char* sim_args_env = std::getenv("HEXAGON_SIM_ARGS");
+  if (sim_args_env == nullptr) sim_args_env = "";
+  auto sim_args_iss = std::istringstream(std::string(sim_args_env));
+  using iterator = std::istream_iterator<std::string>;
+  auto sim_args = string_list(iterator(sim_args_iss), iterator());
+
+  std::string target_str = !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66");
+
+  arch_ = target_str;
+  sim_ = detail::make_unique<HexagonWrapper>(detail::non_const_str(target_str));
+  LOG(INFO) << "HexagonSimulator: Core version: " << arch_;
+
+  // Locate the sim_dev binary in PATH, or in the current working directory.
+  std::string sim_dev = "sim_dev";
+  detail::MaybeString path_sim_dev = detail::FindInEnvPath("PATH", sim_dev);
+  if (!path_sim_dev) {
+    if (!detail::FileExists(sim_dev)) {
+      LOG(FATAL) << "Cannot find sim_dev in PATH.";
+    }
+    path_sim_dev = sim_dev;
+  }
+
+  CHECKED_CALL(ConfigureExecutableBinary, path_sim_dev->c_str());
+
+  std::vector<std::string> app_args = {*path_sim_dev};
+  if (char* ev = getenv("ADSP_LIBRARY_PATH")) {
+    app_args.push_back("-L");
+    app_args.push_back(ev);
+  }
+  sim_dev_args_ = detail::non_const_str(app_args);
+  CHECKED_CALL(ConfigureAppCommandLine, sim_dev_args_.size(), sim_dev_args_);
+
+  Configure(sim_args);
+
+  CHECKED_CALL(EndOfConfiguration);
+  CHECKED_CALL(LoadExecutableBinary);
+  CHECKED_CALL(ReadSymbolValue, "dispatch", &dispatch_v_);
+  CHECKED_CALL(ReadSymbolValue, "message_buffer", &message_buffer_v_);
+  CHECKED_CALL(SetBreakpoint, dispatch_v_);
+
+  HEXAPI_CoreState core = HEX_CORE_RESET;
+
+  HEX_4u_t result;
+  core = sim_->Run(&result);
+  if (core != HEX_CORE_BREAKPOINT) {
+    LOG(FATAL) << "HexagonSimulator: Run not stopped on breakpoint, "
+                  "code="
+               << static_cast<int>(core);
+  }
+
+  // At this point the simulator has executed the executable's initialization
+  // code that could have written to the SSR register.
+  // Enable UPCYCLE register.
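Per the comment above, the constructor finishes by enabling cycle counting: it reads the current thread's SSR register, sets bit 23, and writes the value back. A generic sketch of that read-modify-write (the register-access calls themselves are simulator-specific and omitted here):

```cpp
#include <cassert>
#include <cstdint>

// Read-modify-write bit set, the same shape as the SSR update that follows
// (bit 23 enables UPCYCLE per the surrounding code).
uint32_t SetBit(uint32_t reg, unsigned bit) { return reg | (1u << bit); }

int main() {
  uint32_t ssr = 0;          // value read from the thread register
  ssr = SetBit(ssr, 23);     // modify
  assert(ssr == (1u << 23)); // value written back
}
```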
+ HEX_4u_t thread_num; + CHECKED_CALL(GetCurrentHWThreadNum, &thread_num); + HEX_4u_t thread_ssr; + CHECKED_CALL(ReadThreadRegister, thread_num, TH_REG_SSR, &thread_ssr); + thread_ssr |= (1 << 23); + CHECKED_CALL(WriteThreadRegister, thread_num, TH_REG_SSR, thread_ssr); +} + +void* HexagonSimulator::Alloc(unsigned size, unsigned align) { + LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align << ')'; + Message m = {kAlloc, sizeof(MsgAlloc), 0u}; + MsgAlloc ma = {size, align}; + SendMsg(m, &ma, true); + + ICHECK_EQ(sizeof(MsgPointer), m.len); + MsgPointer mp; + CopyFromV(&mp, m.va, m.len); + + LOG(INFO) << "HexagonSimulator::Alloc -> " << std::hex << mp.va << std::dec; + ICHECK_NE(mp.va, 0); + return va2p(mp.va); +} + +void HexagonSimulator::Free(void* ptr) { + LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec << ')'; + if (task_queuing_) { + Message mf = {kFlush, 0, 0}; + SendMsg(mf, nullptr, true); + } + Message m = {kFree, sizeof(MsgPointer), 0u}; + MsgPointer mp = {p2va(ptr)}; + SendMsg(m, &mp, true); +} + +void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { + LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size << ", align=" << align << ')'; + Message m = {kAllocVtcm, sizeof(MsgAlloc), 0u}; + MsgAlloc ma = {size, align}; + SendMsg(m, &ma, true); + + ICHECK_EQ(sizeof(MsgPointer), m.len); + MsgPointer mp; + CopyFromV(&mp, m.va, m.len); + + LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va << std::dec; + ICHECK_NE(mp.va, 0); + return va2p(mp.va); +} + +void HexagonSimulator::FreeVtcm(void* ptr) {} + +void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst << ", src=" << src + << ", len=" << std::dec << len << ')'; + ICHECK(dst != nullptr && src != nullptr); + Message m = {kCopy, sizeof(MsgCopy), 0u}; + MsgCopy mc = {p2va(dst), p2va(src), len}; + SendMsg(m, &mc, true); +} + +void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst << ", src=" << src + << ", len=" << len << ')'; + if (task_queuing_) { + Message mf = {kFlush, 0, 0}; + SendMsg(mf, nullptr, true); + } + CopyFromV(host_dst, p2va(src), len); +} + +void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst << ", host_src=" << host_src + << ", len=" << len << ')'; + CopyToV(p2va(dst), host_src, len); +} + +void* HexagonSimulator::Load(const std::string& data, const std::string& fmt) { + // Load the shared library. 
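Taken together, `Load`, `Resolve`, and `Call` form a small message-based RPC: ship a module, look up a symbol's device VA, then marshal scalar and stack words to it. A hypothetical driver-side usage sketch (module path, format, and symbol name are made up, and `DemoDevice` only mimics the call sequence with stub bodies); the `kLoad` message construction resumes below:

```cpp
#include <cstdint>
#include <string>

// DemoDevice mirrors only the call sequence of the Device interface above;
// the bodies are stubs so the sketch is self-contained.
struct DemoDevice {
  void* Load(const std::string&, const std::string&) { return &tag_; }
  void* Resolve(const std::string&) { return &tag_; }
  void Call(void*, uint32_t*, unsigned, uint32_t*, unsigned) {}
  int tag_{0};
};

int main() {
  DemoDevice dev;
  void* mod = dev.Load("libdemo.so", "so");  // ship the module to the device
  void* fn = dev.Resolve("demo_kernel");     // device VA of the symbol
  uint32_t args[2] = {16, 32};
  dev.Call(fn, args, 2, nullptr, 0);         // scalar args only, no stack words
  (void)mod;
}
```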
+ Message m = {kLoad, static_cast(data.size() + 1), 0u}; + SendMsg(m, data.c_str(), false); + + ICHECK_EQ(sizeof(MsgPointer), m.len); + MsgPointer mp; + CopyFromV(&mp, m.va, sizeof(mp)); + + return va2p(mp.va); +} + +void HexagonSimulator::Unload(void* mod) { + ICHECK(mod); + Message m = {kUnload, sizeof(MsgPointer), 0u}; + MsgPointer mp = {p2va(mod)}; + SendMsg(m, &mp, false); +} + +void* HexagonSimulator::Resolve(const std::string& sym) { + LOG(INFO) << "HexagonSimulator::Resolve(sym=" << sym << ')'; + Message m = {kResolve, static_cast(sym.size() + 1), 0u}; + SendMsg(m, sym.c_str(), true); + + ICHECK_EQ(sizeof(MsgPointer), m.len); + MsgPointer mp; + CopyFromV(&mp, m.va, sizeof(mp)); + + LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va << std::dec; + return va2p(mp.va); +} + +void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) { + LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func << ", scalar=" << scalar + << ", sc_num=" << std::dec + << sc_num + // NOLINTNEXTLINE(build/include_what_you_use) + << ", stack=" << std::hex << stack << ", st_num=" << std::dec << st_num; + + std::vector data; + + // Copy the MsgCall contents into the data vector as a sequence of uints. + MsgCall me = {p2va(func), sc_num, st_num}; + + ICHECK((is_multiple_of())); + for (unsigned i = 0, e = sizeof(me) / sizeof(uint32_t); i != e; ++i) + data.push_back(reinterpret_cast(&me)[i]); + + // Append the scalar (register) arguments. + for (unsigned i = 0; i != sc_num; ++i) data.push_back(scalar[i]); + // Append the stack contents. + for (unsigned i = 0; i != st_num; ++i) data.push_back(stack[i]); + + std::ostringstream log_data; + log_data << "data: {" << std::hex; + for (unsigned i = 0, e = static_cast(data.size()); i != e; ++i) { + log_data << ' ' << reinterpret_cast(data.data())[i]; + } + log_data << std::dec << " }" << std::flush; + LOG(INFO) << log_data.str(); + + Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), 0u}; + SendMsg(m, data.data(), true); + + if (!task_queuing_) { + Message mf = {kFlush, 0, 0}; + SendMsg(mf, nullptr, true); + } + + std::vector rv(m.len); + CopyFromV(rv.data(), m.va, m.len); + + std::ostringstream log_rv; + log_rv << "HexagonSimulator::Call -> {" << std::hex; + for (unsigned i = 0, e = std::min(rv.size(), 4u); i != e; ++i) { + log_rv << ' ' << std::setw(2) << std::setfill('0') << static_cast(rv[i]); + } + if (rv.size() > 4) log_rv << "..."; + log_rv << std::dec << " }"; + LOG(INFO) << log_rv.str(); +} + +bool HexagonSimulator::Configure(string_list& opts) { + while (!opts.empty()) { + std::string key = *detail::pop_front(opts); + auto f = opt_map_.find(key); + if (f == opt_map_.end()) { + LOG(FATAL) << "Unrecognized simulator option: " << key; + // unreachable + } + ICHECK((this->*f->second)(opts)) << "error handling option: " << key; + } + + // Check AHB. + if (ahb_.first.has_value() && ahb_.second.has_value()) { + CHECKED_CALL(ConfigureAHB, *ahb_.first, *ahb_.second); + } else { + ICHECK(!ahb_.first.has_value() && !ahb_.second.has_value()) + << "HexagonSimulator: please specify both low and high addresses " + "for AHB"; + } + + // Check AXI2. 
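`Configure` above pops one flag at a time and dispatches it through `opt_map_`, a map from flag string to pointer-to-member handler that consumes the flag's own arguments. A minimal sketch of that dispatch pattern (`Parser` and its single flag are hypothetical); the paired AXI2 low/high address check resumes below:

```cpp
#include <cassert>
#include <deque>
#include <map>
#include <string>

class Parser {
 public:
  // Pop a flag, look up its handler, let the handler consume its arguments.
  bool Run(std::deque<std::string>& opts) {
    while (!opts.empty()) {
      std::string key = opts.front();
      opts.pop_front();
      auto f = table_.find(key);
      if (f == table_.end()) return false;            // unrecognized flag
      if (!(this->*(f->second))(opts)) return false;  // handler failed
    }
    return true;
  }
  bool seen_quiet{false};

 private:
  bool HandleQuiet(std::deque<std::string>&) { return seen_quiet = true; }
  using Handler = bool (Parser::*)(std::deque<std::string>&);
  std::map<std::string, Handler> table_{{"--quiet", &Parser::HandleQuiet}};
};

int main() {
  std::deque<std::string> opts = {"--quiet"};
  Parser p;
  assert(p.Run(opts) && p.seen_quiet);
}
```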
+  if (axi2_.first.has_value() && axi2_.second.has_value()) {
+    CHECKED_CALL(ConfigureAXI2, *axi2_.first, *axi2_.second);
+  } else {
+    ICHECK(!axi2_.first.has_value() && !axi2_.second.has_value())
+        << "HexagonSimulator: please specify both low and high addresses "
+           "for AXI2";
+  }
+
+  return true;
+}
+
+bool HexagonSimulator::HandleAHBBusPenalty(string_list& rest) {
+  auto penalty = detail::to_uint(detail::pop_front(rest));
+  auto interval = to_interval(detail::pop_front(rest));
+  if (penalty && interval) {
+    CHECKED_CALL(ConfigureAHBBusPenalty, *penalty, *interval);
+  }
+  return static_cast<bool>(penalty) && static_cast<bool>(interval);
+}
+
+bool HexagonSimulator::HandleAHBBusRatio(string_list& rest) {
+  auto ratio = detail::to_float(detail::pop_front(rest));
+  if (ratio) {
+    CHECKED_CALL(ConfigureAHBBusRatio, *ratio);
+  }
+  return static_cast<bool>(ratio);
+}
+
+bool HexagonSimulator::HandleAHBHighAddr(string_list& rest) {
+  auto addr = detail::to_uint(detail::pop_front(rest));
+  ICHECK(addr) << "HexagonSimulator: invalid value for AHB high address";
+  if (addr) {
+    ahb_.second = *addr;
+  }
+  return static_cast<bool>(addr);
+}
+
+bool HexagonSimulator::HandleAHBLowAddr(string_list& rest) {
+  auto addr = detail::to_uint(detail::pop_front(rest));
+  ICHECK(addr) << "HexagonSimulator: invalid value for AHB low address";
+  if (addr) {
+    ahb_.first = *addr;
+  }
+  return static_cast<bool>(addr);
+}
+
+bool HexagonSimulator::HandleAXI2BusPenalty(string_list& rest) {
+  auto penalty = detail::to_uint(detail::pop_front(rest));
+  auto interval = to_interval(detail::pop_front(rest));
+  if (penalty && interval) {
+    CHECKED_CALL(ConfigureAXI2BusPenalty, *penalty, *interval);
+  }
+  return static_cast<bool>(penalty) && static_cast<bool>(interval);
+}
+
+bool HexagonSimulator::HandleAXI2BusRatio(string_list& rest) {
+  auto ratio = detail::to_float(detail::pop_front(rest));
+  if (ratio) {
+    CHECKED_CALL(ConfigureAXI2BusRatio, *ratio);
+  }
+  return static_cast<bool>(ratio);
+}
+
+bool HexagonSimulator::HandleAXI2HighAddr(string_list& rest) {
+  auto addr = detail::to_uint(detail::pop_front(rest));
+  ICHECK(addr) << "HexagonSimulator: invalid value for AXI2 high address";
+  if (addr) {
+    axi2_.second = *addr;
+  }
+  return static_cast<bool>(addr);
+}
+
+bool HexagonSimulator::HandleAXI2LowAddr(string_list& rest) {
+  auto addr = detail::to_uint(detail::pop_front(rest));
+  ICHECK(addr) << "HexagonSimulator: invalid value for AXI2 low address";
+  if (addr) {
+    axi2_.first = *addr;
+  }
+  return static_cast<bool>(addr);
+}
+
+bool HexagonSimulator::HandleBuildTag(string_list& rest) {
+  sim_->PrintBuildTag();
+  return true;
+}
+
+bool HexagonSimulator::HandleBusPenalty(string_list& rest) {
+  auto penalty = detail::to_uint(detail::pop_front(rest));
+  auto interval = to_interval(detail::pop_front(rest));
+  if (penalty && interval) {
+    CHECKED_CALL(ConfigureBusPenalty, *penalty, *interval);
+  }
+  return static_cast<bool>(penalty) && static_cast<bool>(interval);
+}
+
+bool HexagonSimulator::HandleBusRatio(string_list& rest) {
+  auto ratio = detail::to_float(detail::pop_front(rest));
+  if (ratio) {
+    CHECKED_CALL(ConfigureBusRatio, *ratio);
+  }
+  return static_cast<bool>(ratio);
+}
+
+bool HexagonSimulator::HandleBusTrace(string_list& rest) {
+  auto file = detail::pop_front(rest);
+  if (file) {
+    CHECKED_CALL(SetTracing, HEX_TRACE_BUS, file->c_str());
+  }
+  return static_cast<bool>(file);
+}
+
+bool HexagonSimulator::HandleBypassIdle(string_list& rest) {
+  CHECKED_CALL(ConfigureBypassIdle, true);
+  return true;
+}
+
+bool HexagonSimulator::HandleConnectionTimeout(string_list& rest) {
+  auto time =
detail::to_int(detail::pop_front(rest)); + if (time) { + CHECKED_CALL(ConfigureConnectionTimeout, *time); + } + return static_cast(time); +} + +bool HexagonSimulator::HandleCoprocTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_COPROC, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleCoreDump(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureCoreDump, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleCosimFile(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureCosim, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleDCacheTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_DCACHE, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleDSPClock(string_list& rest) { + auto freq = detail::to_uint(detail::pop_front(rest)); + if (freq) { + CHECKED_CALL(ConfigureCoreFrequency, *freq); + } + return static_cast(freq); +} + +bool HexagonSimulator::HandleETMCFGBase(string_list& rest) { + auto base = detail::to_uint(detail::pop_front(rest)); + if (base) { + CHECKED_CALL(ConfigureEtmcfgBase, *base); + } + return static_cast(base); +} + +bool HexagonSimulator::HandleGDBServ(string_list& rest) { + auto port = detail::to_uint(detail::pop_front(rest)); + if (port) { + CHECKED_CALL(ConfigureRemoteDebug, *port); + debug_port_ = *port; + } + return static_cast(port); +} + +bool HexagonSimulator::HandleHVXLength(string_list& rest) { + auto len = detail::to_int(detail::pop_front(rest)); + if (len) { + CHECKED_CALL(ConfigureHVXLength, *len); + } + return static_cast(len); +} + +bool HexagonSimulator::HandleICacheTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_ICACHE, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleL2CacheTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_L2CACHE, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleL2CFGBase(string_list& rest) { + auto base = detail::to_uint(detail::pop_front(rest)); + if (base) { + CHECKED_CALL(ConfigureL2cfgBase, *base); + } + return static_cast(base); +} + +bool HexagonSimulator::HandleL2TCMBase(string_list& rest) { + auto base = detail::to_uint(detail::pop_front(rest)); + if (base) { + CHECKED_CALL(ConfigureL2tcmBase, *base); + } + return static_cast(base); +} + +bool HexagonSimulator::HandleMemFillRand(string_list& rest) { + auto seed = detail::to_uint(detail::pop_front(rest)); + if (seed) { + CHECKED_CALL(ConfigureMemFillRandom, *seed); + } + return static_cast(seed); +} + +bool HexagonSimulator::HandleMemFill(string_list& rest) { + auto val = detail::to_uint(detail::pop_front(rest)); + if (val) { + CHECKED_CALL(ConfigureMemFill, *val); + } + return static_cast(val); +} + +bool HexagonSimulator::HandleMemTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_MEM, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleNullPtr(string_list& rest) { + auto behavior = to_nullptr(detail::pop_front(rest)); + if (behavior) { + CHECKED_CALL(ConfigureNULLPointerBehavior, *behavior); + } + return static_cast(behavior); +} + +bool 
HexagonSimulator::HandlePacketAnalyze(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigurePacketAnalysis, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandlePCFilter(string_list& rest) { + auto range = detail::to_range(detail::pop_front(rest)); + if (range) { + CHECKED_CALL(ConfigurePCRangeFilter, range->first, range->second); + } + return static_cast(range); +} + +bool HexagonSimulator::HandlePCTraceMin(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_PC_MIN, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandlePCTraceNano(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_PC_NANO, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandlePCTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_PC, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandlePMUStatsFile(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigurePmuStatisticsFile, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleProfile(string_list& rest) { + auto path = detail::pop_front(rest); + if (path) { + CHECKED_CALL(ConfigureGProf, path->c_str()); + } + return static_cast(path); +} + +bool HexagonSimulator::HandleProfileTimeZero(string_list& rest) { + auto timezero = detail::to_bool(detail::pop_front(rest)); + if (timezero) { + CHECKED_CALL(ConfigureProfileMode, *timezero); + } + return static_cast(timezero); +} + +bool HexagonSimulator::HandleQuiet(string_list& rest) { + sim_->VerboseMode(HEX_QUIET); + return true; +} + +bool HexagonSimulator::HandleReconnect(string_list& rest) { + if (!debug_port_) { + LOG(FATAL) << "Reconnect error: --reconnect must be specified " + "AFTER --gdbserv "; + } + CHECKED_CALL(ConfigureRemoteDebug, *debug_port_, true); + return true; +} + +bool HexagonSimulator::HandleRTOS(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureOSAwareness, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleSimErr(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureSimStderr, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleSimIn(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureSimStdin, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleSimOut(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureSimStdout, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleStackStart(string_list& rest) { + auto base = detail::to_uint(detail::pop_front(rest)); + auto size = detail::to_uint(detail::pop_front(rest)); + if (base && size) { + CHECKED_CALL(ConfigureStackInfo, *base, *size); + } + return static_cast(base) && static_cast(size); +} + +bool HexagonSimulator::HandleStallTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_STALL, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleStatsFile(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureStatisticsFile, 
file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleSubsystemBase(string_list& rest) { + auto base = detail::to_uint(detail::pop_front(rest)); + if (base) { + CHECKED_CALL(ConfigureSubsystemBase, *base); + } + return static_cast(base); +} + +bool HexagonSimulator::HandleSymFile(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(AddSymbolFile, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleTCM(string_list& rest) { + CHECKED_CALL(ConfigureTimingMode, HEX_TIMING); + return true; +} + +bool HexagonSimulator::HandleTCMHighAddr(string_list& rest) { + // This option takes an argument, but (the option) is ignored. + auto addr = detail::to_uint(detail::pop_front(rest)); + return static_cast(addr); +} + +bool HexagonSimulator::HandleTCMLowAddr(string_list& rest) { + auto addr = detail::to_uint(detail::pop_front(rest)); + if (addr) { + CHECKED_CALL(ConfigureTCM, *addr); + } + return static_cast(addr); +} + +bool HexagonSimulator::HandleTimeFilterNS(string_list& rest) { + auto range = detail::to_range(detail::pop_front(rest)); + if (range) { + CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, range->second, HEX_NANOSEC); + } + return static_cast(range); +} + +bool HexagonSimulator::HandleTiming(string_list& rest) { + HEXAPI_TimingMode timing_mode = HEX_TIMING; + // The argument to --timing is optional. + if (should_parse_next(rest)) { + if (auto mode = to_timingmode(detail::pop_front(rest))) { + timing_mode = *mode; + } else { + return false; + } + } + CHECKED_CALL(ConfigureTimingMode, timing_mode); + return true; +} + +bool HexagonSimulator::HandleUArchTrace(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(SetTracing, HEX_TRACE_UARCH, file->c_str()); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleUseFS(string_list& rest) { + auto file = detail::pop_front(rest); + if (file) { + CHECKED_CALL(ConfigureARFilesystem, detail::non_const_str(*file)); + } + return static_cast(file); +} + +bool HexagonSimulator::HandleV2PTranslation(string_list& rest) { + auto enable = detail::to_bool(detail::pop_front(rest)); + if (enable) { + CHECKED_CALL(EnableVirtualToPhysicalTranslation, *enable); + } + return static_cast(enable); +} + +bool HexagonSimulator::HandleVerbose(string_list& rest) { + auto mode = to_verbosemode(detail::pop_front(rest)); + if (mode) { + sim_->VerboseMode(*mode); + } + return static_cast(mode); +} + +bool HexagonSimulator::should_parse_next(const string_list& rest) { + if (auto str = detail::front(rest)) { + return str->empty() || str->front() != '-'; + } + return false; +} + +detail::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { + auto none = detail::Optional(); + if (!str) return none; + + if (auto val = detail::to_int(*str)) { + switch (*val) { + case HEX_MILLISEC: + case HEX_MICROSEC: + case HEX_NANOSEC: + case HEX_PICOSEC: + case HEX_PCYCLE: + return static_cast(*val); + } + } + + return detail::StringSwitch>(*str) + .Case("MILLISEC", HEX_MILLISEC) + .Case("MICROSEC", HEX_MICROSEC) + .Case("NANOSEC", HEX_NANOSEC) + .Case("PICOSEC", HEX_PICOSEC) + .Case("PCYCLE", HEX_PCYCLE) + .Default(none); +} + +detail::Optional HexagonSimulator::to_timingmode( + const detail::MaybeString& str) { + auto none = detail::Optional(); + if (!str) return none; + + if (auto val = detail::to_int(*str)) { + switch (*val) { + case HEX_NOTIMING: + case HEX_TIMING_NODBC: + case HEX_TIMING: + case 
HEX_TIMING_COHERENCY: + return static_cast(*val); + } + } + + return detail::StringSwitch>(*str) + .Case("NOTIMING", HEX_NOTIMING) + .Case("TIMING_NODBC", HEX_TIMING_NODBC) + .Case("TIMING", HEX_TIMING) + .Case("TIMING_COHERENCY", HEX_TIMING_COHERENCY) + .Default(none); +} + +detail::Optional HexagonSimulator::to_verbosemode( + const detail::MaybeString& str) { + auto none = detail::Optional(); + if (!str) return none; + + if (auto val = detail::to_int(*str)) { + switch (*val) { + case HEX_SILENT: + case HEX_QUIET: + case HEX_NORMAL: + case HEX_VERBOSE: + case HEX_REALLY_VERBOSE: + return static_cast(*val); + } + } + + return detail::StringSwitch>(*str) + .Case("SILENT", HEX_SILENT) + .Case("QUIET", HEX_QUIET) + .Case("NORMAL", HEX_NORMAL) + .Case("VERBOSE", HEX_VERBOSE) + .Case("REALLY_VERBOSE", HEX_REALLY_VERBOSE) + .Default(none); +} + +detail::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { + auto none = detail::Optional(); + if (!str) return none; + + if (auto val = detail::to_int(*str)) { + switch (*val) { + case HEX_NULLPTR_IGNORE: + case HEX_NULLPTR_WARN: + case HEX_NULLPTR_FATAL: + case HEX_NULLPTR_PCZERO: + return static_cast(*val); + } + } + + return detail::StringSwitch>(*str) + .Case("IGNORE", HEX_NULLPTR_IGNORE) + .Case("WARN", HEX_NULLPTR_WARN) + .Case("FATAL", HEX_NULLPTR_FATAL) + .Case("PCZERO", HEX_NULLPTR_PCZERO) + .Default(none); +} + +std::string HexagonSimulator::to_string(HEXAPI_Status status) { + switch (status) { + case HEX_STAT_ERROR: + return "ERROR"; + case HEX_STAT_SUCCESS: + return "SUCCESS"; + case HEX_STAT_CANNOT_CONFIG: + return "CANNOT_CONFIG"; + case HEX_STAT_INVALID_ARGS: + return "INVALID_ARGS"; + case HEX_STAT_RANGE_ERROR: + return "RANGE_ERROR"; + case HEX_STAT_FILE_ACCESS_ERROR: + return "FILE_ACCESS_ERROR"; + case HEX_STAT_DEVICE_NOT_FOUND: + return "DEVICE_NOT_FOUND"; + case HEX_STAT_MEM_ACCESS_ERROR: + return "MEM_ACCESS_ERROR"; + case HEX_STAT_CANNOT_TRANSLATE: + return "CANNOT_TRANSLATE"; + case HEX_STAT_NO_ACTIVE_THREADS: + return "NO_ACTIVE_THREADS"; + case HEX_STAT_LOAD_ELF_ERROR: + return "LOAD_ELF_ERROR"; + case HEX_STAT_CORE_RESET: + return "CORE_RESET"; + default: + return "unknown"; + } +} + +} // namespace hexagon +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index aac75002c2..160bd1de74 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -64,6 +64,7 @@ class HexagonModuleNode : public runtime::ModuleNode { const char* type_key() const final { return "hexagon"; } void SaveToFile(const std::string& file_name, const std::string& format) override; void SaveToBinary(dmlc::Stream* stream) override; + bool IsDSOExportable() const final { return true; } protected: std::string data_; diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc new file mode 100644 index 0000000000..c699d22916 --- /dev/null +++ b/src/runtime/relax_vm/builtin.cc @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/runtime/relax_vm/builtin.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +using tvm::runtime::NDArray; + +TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); + +TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body_typed([](NDArray src) { return src; }); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap") + .set_body_typed([](void* vm_ptr, ShapeTuple size) { + VirtualMachine* vm = static_cast(vm_ptr); + return NDArray::Empty(size, DLDataType{kDLInt, 64, 1}, vm->devices[0]); + }); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_closure").set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector cap_vars; + for (int i = 1; i < args.size(); ++i) { + cap_vars.push_back(args[i]); + } + String func_name = args[0]; + VMClosure vm_closure(func_name, cap_vars); + + *rv = std::move(vm_closure); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.invoke_closure").set_body([](TVMArgs args, TVMRetValue* rv) { + // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments + void* vm_ptr = args[0]; + VirtualMachine* vm = static_cast(vm_ptr); + VMClosure vm_closure = args[1]; + runtime::String func_name = vm_closure->func_name; + + PackedFunc func{nullptr}; + func = vm->GetFunction(func_name, GetObjectPtr(vm)); + ICHECK(func != nullptr) << "cannot find closure " << func_name; + + // get closure free_vars + Array cap_vars = vm_closure->free_vars; + size_t num_tensor_args = args.size() - 2; + std::vector values(num_tensor_args + cap_vars.size()); + std::vector tcodes(num_tensor_args + cap_vars.size()); + + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + for (size_t i = 0; i < num_tensor_args; i++) { + NDArray arg = args[i + 2]; + setter(i, arg); + } + for (size_t i = 0; i < cap_vars.size(); i++) { + setter(i + num_tensor_args, cap_vars[i]); + } + TVMArgs func_args(values.data(), tcodes.data(), values.size()); + func.CallPacked(func_args, rv); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.store_shape") + .set_body_typed([](ShapeTuple shape, NDArray heap, ShapeTuple indexes) { + int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); + for (size_t i = 0; i < indexes.size(); ++i) { + int64_t heap_idx = indexes[i]; + ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]); + heap_data[heap_idx] = shape[i]; + } + }); + +TVM_REGISTER_GLOBAL("vm.builtin.load_shape").set_body_typed([](NDArray heap, ShapeTuple indexes) { + int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); + std::vector shape; + for (size_t i = 0; i < indexes.size(); ++i) { + int64_t heap_idx = indexes[i]; + ICHECK(heap_idx >= 0 && heap_idx < heap.Shape()[0]); + shape.push_back(heap_data[heap_idx]); + } + return ShapeTuple(shape); +}); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") + .set_body_typed([](void* vm_ptr, ShapeTuple buffer_size, Index device_index, + DLDataType dtype_hint) { + ICHECK_EQ(buffer_size.size(), 1); + int alignment = runtime::kAllocAlignment; + VirtualMachine* vm = static_cast(vm_ptr); + 
ICHECK_LT(device_index, vm->devices.size()) + << "The device index is out of VM physical devices list"; + + if (device_index == -1) { + // Allocate on host. Host is always the last element of vm->devices. + device_index = vm->devices.size() - 1; + } + + int64_t size_imm = buffer_size[0]; + + auto storage_obj = runtime::SimpleObjAllocator().make_object(); + auto* alloc = vm->allocators[device_index]; + ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + storage_obj->buffer = alloc->Alloc(size_imm, alignment, dtype_hint); + Storage storage(storage_obj); + return storage; + }); + +TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); + +TVM_REGISTER_GLOBAL("vm.binary_broadcast_shape_infer") + .set_body_typed([](ShapeTuple lhs_shape, ShapeTuple rhs_shape) { + std::vector output_shape; + size_t ndim0 = lhs_shape.size(); + size_t ndim1 = rhs_shape.size(); + size_t i = 1; + for (; i <= std::min(ndim0, ndim1); ++i) { + int64_t lhs_dim = lhs_shape[ndim0 - i]; + int64_t rhs_dim = rhs_shape[ndim1 - i]; + ICHECK(lhs_dim == rhs_dim || lhs_dim == 1 || rhs_dim == 1); + output_shape.push_back(std::max(lhs_dim, rhs_dim)); + } + size_t max_ndim = std::max(ndim0, ndim1); + ShapeTuple& longer_shape = (ndim0 > ndim1) ? lhs_shape : rhs_shape; + for (; i <= max_ndim; ++i) { + output_shape.push_back(longer_shape[max_ndim - i]); + } + return ShapeTuple(output_shape.rbegin(), output_shape.rend()); + }); + +TVM_REGISTER_GLOBAL("vm.call_tir_dyn").set_body([](TVMArgs args, TVMRetValue* rv) { + void* vm_ptr = args[0]; + VirtualMachine* vm = static_cast(vm_ptr); + runtime::String func_name = args[1]; + + PackedFunc func{nullptr}; + if (vm->lib.defined()) { + func = vm->lib.value()->GetFunction(func_name, true); + } + if (!func.defined()) { + const PackedFunc* p_func = Registry::Get(func_name); + CHECK(p_func != nullptr); + func = *(p_func); + } + + ShapeTuple to_unpack = args[args.size() - 1]; + size_t num_tensor_args = args.size() - 3; + std::vector values(num_tensor_args + to_unpack.size()); + std::vector tcodes(num_tensor_args + to_unpack.size()); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + for (size_t i = 0; i < num_tensor_args; i++) { + NDArray arg = args[i + 2]; + setter(i, arg); + } + for (size_t i = 0; i < to_unpack.size(); i++) { + setter(i + num_tensor_args, to_unpack[i]); + } + + TVMArgs func_args(values.data(), tcodes.data(), values.size()); + func.CallPacked(func_args, rv); +}); + +TVM_REGISTER_GLOBAL("vm.runtime.TupleGetItem") + .set_body_typed([](runtime::ADT adt, ShapeTuple index) { + ICHECK_EQ(index.size(), 1); + int idx = index[0]; + ICHECK_LT(idx, adt.size()); + return adt[idx]; + }); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/bytecode.cc b/src/runtime/relax_vm/bytecode.cc new file mode 100644 index 0000000000..9084207848 --- /dev/null +++ b/src/runtime/relax_vm/bytecode.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/bytecode.cc + * \brief The bytecode for Relax virtual machine. + */ + +#include +#include + +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +Instruction Instruction::Call(Index func_idx, Index num_args, Instruction::Arg* args, RegName dst) { + Instruction instr; + instr.op = Opcode::Call; + instr.dst = dst; + instr.func_idx = func_idx; + instr.num_args = num_args; + instr.args = args; + return instr; +} + +Instruction Instruction::Ret(RegName result) { + Instruction instr; + instr.op = Opcode::Ret; + instr.result = result; + return instr; +} + +Instruction Instruction::Goto(Index pc_offset) { + Instruction instr; + instr.op = Opcode::Goto; + instr.pc_offset = pc_offset; + return instr; +} + +Instruction Instruction::If(RegName cond, Index false_offset) { + Instruction instr; + instr.op = Opcode::If; + instr.cond = cond; + instr.false_offset = false_offset; + return instr; +} +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc new file mode 100644 index 0000000000..db9a278760 --- /dev/null +++ b/src/runtime/relax_vm/executable.cc @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/relax_vm/executable.cc + */ + +#include +#include +#include +#include + +#include +#include + +#include "../file_utils.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +/*! \brief The magic number for the serialized VM bytecode file */ +constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; + +/*! \brief Possible types in the constant pool */ +enum ConstantType : int { + kNDArray = 0, + kDLDataType = 1, + kShapeTuple = 2, + kString = 3, + kInt = 4, +}; + +#define STREAM_CHECK(val, section) \ + ICHECK(val) << "Invalid VM file format in the " << section << " section." 
\ + << "\n"; + +TVM_REGISTER_OBJECT_TYPE(VMClosureObj); + +VMClosure::VMClosure(String func_name, Array free_vars) { + auto ptr = make_object(); + ptr->func_name = func_name; + ptr->free_vars = std::move(free_vars); + data_ = std::move(ptr); +} + +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); + } else if (name == "as_text") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsText(); }); + } else if (name == "as_python") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->AsPython(); }); + } else if (name == "vm_load_executable") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ObjectPtr vm = make_object(); + ICHECK(sptr_to_self.get() == this); + vm->LoadExecutable(GetObjectPtr(this)); + *rv = Module(vm); + }); + } + return nullptr; +} + +std::string Executable::Stats() const { + std::ostringstream oss; + oss << "Relax VM executable statistics:" << std::endl; + + // Get the number of constants. + // If the constant is an NDArray, get the shape of each of them. + // If the constant is an DLDataType, get the data type of each of them. + oss << " Constant pool (# " << constants.size() << "): ["; + for (const auto& it : constants) { + if (it.IsObjectRef()) { + const auto ndarray = it.operator tvm::runtime::NDArray(); + const auto& shape = ndarray.Shape(); + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], "; + } else if (it.IsObjectRef()) { + ShapeTuple shape = it.operator ShapeTuple(); + oss << "shapetuple["; + for (size_t i = 0; i < shape.size(); ++i) { + oss << shape.at(i) << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], "; + } else if (it.IsObjectRef()) { + std::string f = it.AsObjectRef().operator std::string(); + oss << "\""; + oss << f; + oss << "\", "; + } else if (it.type_code() == kDLInt) { + oss << static_cast(it); + oss << ", "; + } else { + try { + DLDataType dtype = it.operator DLDataType(); + oss << dtype; + oss << ", "; + } catch (std::exception& exc) { + LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + << ArgTypeCode2Str(it.type_code()); + } + } + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << global_funcs.size() << "): ["; + for (const auto& it : global_funcs) { + oss << it.name << ", "; + } + if (!global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of packed funcs and the name of each of them. 
+ oss << " Packed functions (#" << func_names.size() << "): ["; + for (const auto& it : func_names) { + oss << it << ", "; + } + if (!func_names.empty()) { + oss.seekp(-2, oss.cur); + } + oss << "]" << std::endl; + + return oss.str(); +} + +void Executable::SetInstructionData(Index i, Index j, ExecWord val) { + Index instr_idx = instr_offset[i]; + instr_data[instr_idx + j] = val; +} + +Instruction Executable::GetInstruction(Index i) const { + Index offset = instr_offset[i]; + Opcode op = static_cast(instr_data[offset]); + switch (op) { + case Opcode::Call: { + RegName dst = instr_data[offset + 1]; + Index func_idx = instr_data[offset + 2]; + Index num_args = instr_data[offset + 3]; + ExecWord* args = const_cast(&instr_data[offset + 4]); + return Instruction::Call(func_idx, num_args, reinterpret_cast(args), dst); + } + case Opcode::Ret: { + RegName result = instr_data[offset + 1]; + return Instruction::Ret(result); + } + case Opcode::Goto: { + Index pc_offset = instr_data[offset + 1]; + return Instruction::Goto(pc_offset); + } + case Opcode::If: { + RegName cond = instr_data[offset + 1]; + Index false_offset = instr_data[offset + 2]; + return Instruction::If(cond, false_offset); + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(op); + break; + } + return Instruction(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +void Executable::SaveToBinary(dmlc::Stream* stream) { + std::string code; + // Initialize the stream object. + dmlc::MemoryStringStream strm(&code); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Packedfunc names section. + SavePackedFuncNames(&strm); + + // Code section. + SaveCodeSection(&strm); + + stream->Write(code); +} + +void Executable::SaveToFile(const std::string& file_name, const std::string& format) { + std::string data; + dmlc::MemoryStringStream writer(&data); + dmlc::SeekStream* strm = &writer; + Executable::SaveToBinary(strm); + runtime::SaveBinaryToFile(file_name, data); +} + +Module Executable::LoadFromBinary(void* stream) { + std::string code; + static_cast(stream)->Read(&code); + dmlc::MemoryStringStream strm(&code); + + ObjectPtr exec = make_object(); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Packedfunc names section. + exec->LoadPackedFuncNames(&strm); + + // Code section. 
+ exec->LoadCodeSection(&strm); + + return Module(exec); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.Executable") + .set_body_typed(Executable::LoadFromBinary); + +Module Executable::LoadFromFile(const std::string& file_name) { + std::string data; + runtime::LoadBinaryFromFile(file_name, &data); + dmlc::MemoryStringStream reader(&data); + dmlc::Stream* strm = &reader; + return Executable::LoadFromBinary(reinterpret_cast(strm)); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.Executable") + .set_body_typed(Executable::LoadFromFile); + +void SerializeVMFunc(const VMFunction& func, dmlc::Stream* strm) { + strm->Write(func.name); + strm->Write(func.start_instr); + strm->Write(func.num_args); + strm->Write(func.register_file_size); + strm->Write(func.param_names); +} + +VMFunction DeserializeVMFunc(dmlc::Stream* strm) { + VMFunction func; + STREAM_CHECK(strm->Read(&func.name), "vmfunc name"); + STREAM_CHECK(strm->Read(&func.start_instr), "vmfunc start_instr"); + STREAM_CHECK(strm->Read(&func.num_args), "vmfunc num_args"); + STREAM_CHECK(strm->Read(&func.register_file_size), "vmfunc register_file_size"); + STREAM_CHECK(strm->Read(&func.param_names), "vmfunc params"); + return func; +} + +void Executable::SaveGlobalSection(dmlc::Stream* strm) { + strm->Write(static_cast(this->global_funcs.size())); + for (const auto& func : this->global_funcs) { + SerializeVMFunc(func, strm); + } +} + +void Executable::SaveConstantSection(dmlc::Stream* strm) { + strm->Write(static_cast(this->constants.size())); + for (const auto& it : this->constants) { + if (it.IsObjectRef()) { + strm->Write(ConstantType::kNDArray); + runtime::SaveDLTensor(strm, it.operator DLTensor*()); + } else if (it.IsObjectRef()) { + ShapeTuple shape = it.operator ShapeTuple(); + strm->Write(ConstantType::kShapeTuple); + strm->Write(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + strm->Write(shape.at(i)); + } + } else if (it.IsObjectRef()) { + String str = it.operator String(); + strm->Write(ConstantType::kString); + strm->Write(str.size()); + for (size_t i = 0; i < str.size(); ++i) { + strm->Write(str.at(i)); + } + } else if (it.type_code() == kDLInt) { + strm->Write(ConstantType::kInt); + strm->Write(it.value()); + } else { + try { + strm->Write(ConstantType::kDLDataType); + strm->Write(it.operator DLDataType()); + } catch (std::exception& exc) { + LOG(FATAL) << "Constant pool can only contain NDArray, DLDataType, and Integers but got " + << ArgTypeCode2Str(it.type_code()); + } + } + } +} + +void Executable::SavePackedFuncNames(dmlc::Stream* strm) { strm->Write(func_names); } + +void Executable::SaveCodeSection(dmlc::Stream* strm) { + strm->Write(instr_offset); + strm->Write(instr_data); +} + +void Executable::LoadGlobalSection(dmlc::Stream* strm) { + uint64_t sz; + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + size_t size = static_cast(sz); + for (size_t i = 0; i < size; i++) { + VMFunction func = DeserializeVMFunc(strm); + this->global_funcs.push_back(func); + } + for (size_t i = 0; i < global_funcs.size(); ++i) { + this->global_map[global_funcs[i].name] = i; + } +} + +void Executable::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + runtime::NDArray ndarray; + DLDataType dtype; + // Load each of the constants. 
+ for (size_t i = 0; i < size; i++) { + int constant_type; + STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant"); + if (constant_type == ConstantType::kNDArray) { + ndarray.Load(strm); + TVMRetValue cell; + cell = ndarray; + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kShapeTuple) { + uint64_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = ShapeTuple(data); + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kDLDataType) { + strm->Read(&dtype); + TVMRetValue cell; + cell = dtype; + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kString) { + uint64_t size; + strm->Read(&size); + std::vector data(size); + for (size_t i = 0; i < size; ++i) { + strm->Read(&(data[i])); + } + TVMRetValue cell; + cell = String(std::string(data.begin(), data.end())); + this->constants.push_back(cell); + } else if (constant_type == ConstantType::kInt) { + int64_t value; + strm->Read(&value); + TVMRetValue cell; + cell = value; + this->constants.push_back(cell); + } else { + LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " + << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool."; + } + } +} + +void Executable::LoadPackedFuncNames(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&(this->func_names)), "packed func names"); + for (size_t i = 0; i < func_names.size(); ++i) { + this->func2idx[func_names[i]] = i; + } +} + +void Executable::LoadCodeSection(dmlc::Stream* strm) { + STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset"); + STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data"); +} + +template +std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ", + std::function repr = std::to_string) { + if (cnt == 0) { + return ""; + } + std::ostringstream oss; + oss << repr(items[offset]); + for (int i = 1; i < cnt; ++i) { + oss << delim << repr(items[offset + i]); + } + return oss.str(); +} + +std::string RegNameToStr(RegName reg) { + if (reg == Instruction::kVoidArg) { + return "void"; + } + if (reg == Instruction::kVMRegister) { + return "%vm"; + } + return "%" + std::to_string(reg); +} + +std::string InstrArgToStr(Instruction::Arg arg) { + // only for argument + switch (arg.kind()) { + case Instruction::kRegister: + return RegNameToStr(arg.value()); + case Instruction::kImmediate: + return "i" + std::to_string(arg.value()); + case Instruction::kConstIdx: + return "c[" + std::to_string(arg.value()) + "]"; + default: + LOG(FATAL) << "Wrong instruction kind: " << arg.kind(); + return ""; + } +} + +std::string InstrArgToPyStr(Instruction::Arg arg) { + switch (arg.kind()) { + case Instruction::kRegister: + if (arg.value() == Instruction::kVMRegister) { + return "ib.r(vm)"; + } + return "ib.r(" + std::to_string(arg.value()) + ")"; + case Instruction::kImmediate: + return "ib.imm(" + std::to_string(arg.value()) + ")"; + case Instruction::kConstIdx: + return "ib.c(" + std::to_string(arg.value()) + ")"; + default: + LOG(FATAL) << "Wrong instruction kind: " << arg.kind(); + return ""; + } +} + +String Executable::AsText() const { + // print the text format + std::ostringstream os; + for (size_t fidx = 0; fidx < this->global_funcs.size(); ++fidx) { + const VMFunction& gfunc = this->global_funcs[fidx]; + os << "@" << gfunc.name << ":\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = this->instr_offset.size(); + if 
((fidx + 1) < global_funcs.size()) { + end_instr = global_funcs[fidx + 1].start_instr; + } + for (size_t idx = start_instr; idx < end_instr; ++idx) { + os << " "; + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << std::setw(6) << std::left << "call" << std::setw(16) << std::left + << this->func_names[instr.func_idx] << " in: " << std::setw(12) << std::left + << StrJoin(instr.args, 0, instr.num_args, ", ", InstrArgToStr) + << " dst: " << RegNameToStr(instr.dst) << "\n"; + break; + } + case Opcode::Ret: { + os << std::setw(6) << std::left << "ret " << RegNameToStr(instr.result) << "\n"; + break; + } + case Opcode::Goto: { + os << std::setw(6) << std::left << "goto" << instr.pc_offset << "\n"; + break; + } + case Opcode::If: { + os << std::setw(6) << std::left << "If" << RegNameToStr(instr.cond) << ", " + << instr.false_offset << "\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + os << "\n"; + } + return String(os.str()); +} + +String Executable::AsPython() const { + // print the python format + std::ostringstream os; + os << "ib = rx.Builder()\n"; + for (size_t fidx = 0; fidx < this->global_funcs.size(); ++fidx) { + const VMFunction& gfunc = this->global_funcs[fidx]; + os << "with ib.function(\"" << gfunc.name << "\", num_inputs=" << gfunc.num_args << "):\n"; + size_t start_instr = gfunc.start_instr; + size_t end_instr = this->instr_offset.size(); + if ((fidx + 1) < global_funcs.size()) { + end_instr = global_funcs[fidx + 1].start_instr; + } + for (size_t idx = start_instr; idx < end_instr; ++idx) { + Instruction instr = this->GetInstruction(idx); + switch (instr.op) { + case Opcode::Call: { + os << " ib.emit_call(\"" << this->func_names[instr.func_idx] << "\", args=[" + << StrJoin(instr.args, 0, instr.num_args, ", ", InstrArgToPyStr) + << "]"; + if (instr.dst != Instruction::kVoidArg) os << ", dst=ib.r(" << instr.dst << ")"; + os << ")\n"; + break; + } + case Opcode::Ret: { + os << " ib.emit_ret(ib.r(" << instr.result << "))\n"; + break; + } + case Opcode::Goto: { + os << " ib.emit_goto(" << instr.pc_offset << ")\n"; + break; + } + case Opcode::If: { + os << " ib.emit_if(ib.r(" << instr.cond << "), " << instr.false_offset << ")\n"; + break; + } + default: + LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + break; + } + } + } + return String(os.str()); +} + +TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(Executable::LoadFromFile); + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/memory_manager.cc b/src/runtime/relax_vm/memory_manager.cc new file mode 100644 index 0000000000..a017b9c6d9 --- /dev/null +++ b/src/runtime/relax_vm/memory_manager.cc @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/relax_vm/memory_manager.cc + * \brief Allocate and manage memory for the Relay VM. + */ +#include + +#include +#include + +#include "naive_allocator.h" +#include "pooled_allocator.h" + +namespace tvm { +namespace runtime { +namespace relax_vm { + +static void BufferDeleter(Object* obj) { + auto* ptr = static_cast(obj); + ICHECK(ptr->manager_ctx != nullptr); + Buffer* buffer = reinterpret_cast(ptr->manager_ctx); + MemoryManager::GetAllocator(buffer->device)->Free(*(buffer)); + delete buffer; + delete ptr; +} + +void StorageObj::Deleter(Object* obj) { + auto* ptr = static_cast(obj); + // When invoking AllocNDArray we don't own the underlying allocation + // and should not delete the buffer, but instead let it be reclaimed + // by the storage object's destructor. + // + // We did bump the reference count by 1 to keep alive the StorageObj + // allocation in case this NDArray is the sole owner. + // + // We decrement the object allowing for the buffer to release our + // reference count from allocation. + StorageObj* storage = reinterpret_cast(ptr->manager_ctx); + storage->DecRef(); + delete ptr; +} + +inline void VerifyDataType(DLDataType dtype) { + ICHECK_GE(dtype.lanes, 1); + if (dtype.code == kDLFloat) { + ICHECK_EQ(dtype.bits % 8, 0); + } else { + // allow uint1 as a special flag for bool. + if (dtype.bits == 1 && dtype.code == kDLUInt) return; + ICHECK_EQ(dtype.bits % 8, 0); + } + ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); +} + +inline size_t GetDataAlignment(const DLTensor& arr) { + size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; + if (align < runtime::kAllocAlignment) return runtime::kAllocAlignment; + return align; +} + +runtime::NDArray StorageObj::AllocNDArray(uint64_t offset, ShapeTuple shape, DLDataType dtype) { + VerifyDataType(dtype); + + // critical zone: allocate header, cannot throw + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, this->buffer.device); + + container->SetDeleter(StorageObj::Deleter); + size_t needed_size = runtime::GetDataSize(container->dl_tensor); + this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. + container->manager_ctx = reinterpret_cast(this); + + // is this UB? + // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start. + auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + runtime::NDArray ret(runtime::GetObjectPtr(container)); + // RAII in effect, now run the check. + + ICHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << "bytes"; + + return ret; +} + +MemoryManager* MemoryManager::Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. 
+ static auto* inst = new MemoryManager(); + return inst; +} + +Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + if (m->allocators_.find(dev) == m->allocators_.end()) { + std::unique_ptr alloc; + switch (type) { + case kNaive: { + DLOG(INFO) << "New naive allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new NaiveAllocator(dev)); + break; + } + case kPooled: { + DLOG(INFO) << "New pooled allocator for " << runtime::DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; + alloc.reset(new PooledAllocator(dev)); + break; + } + default: + LOG(FATAL) << "Unknown allocator type: " << type; + } + auto ret = alloc.get(); + m->allocators_.emplace(dev, std::move(alloc)); + return ret; + } + auto alloc = m->allocators_.at(dev).get(); + if (alloc->type() != type) { + LOG(WARNING) << "The type of existing allocator for " << runtime::DeviceName(dev.device_type) + << "(" << dev.device_id << ") is different from the request type (" + << alloc->type() << " vs " << type << ")"; + } + return alloc; +} + +Allocator* MemoryManager::GetAllocator(Device dev) { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + auto it = m->allocators_.find(dev); + if (it == m->allocators_.end()) { + LOG(FATAL) << "Allocator for " << runtime::DeviceName(dev.device_type) << "(" << dev.device_id + << ") has not been created yet."; + } + return it->second.get(); +} + +runtime::NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLDevice dev) { + VerifyDataType(dtype); + runtime::NDArray::Container* container = + new runtime::NDArray::Container(nullptr, shape, dtype, dev); + container->SetDeleter(BufferDeleter); + size_t size = runtime::GetDataSize(container->dl_tensor); + size_t alignment = GetDataAlignment(container->dl_tensor); + Buffer* buffer = new Buffer; + *buffer = this->Alloc(size, alignment, dtype); + container->manager_ctx = reinterpret_cast(buffer); + container->dl_tensor.data = buffer->data; + return runtime::NDArray(runtime::GetObjectPtr(container)); +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/relax_vm/naive_allocator.h b/src/runtime/relax_vm/naive_allocator.h new file mode 100644 index 0000000000..843a559602 --- /dev/null +++ b/src/runtime/relax_vm/naive_allocator.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/runtime/relax_vm/naive_allocator.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ +#define TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ + +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class NaiveAllocator final : public Allocator { + public: + explicit NaiveAllocator(Device dev) : Allocator(kNaive), used_memory_(0), device_(dev) {} + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + Buffer buf; + buf.device = device_; + buf.size = nbytes; + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, nbytes, alignment, type_hint); + used_memory_.fetch_add(nbytes, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + runtime::DeviceAPI::Get(device_)->FreeDataSpace(buffer.device, buffer.data); + used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed); + DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; + } + + private: + std::atomic used_memory_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_NAIVE_ALLOCATOR_H_ diff --git a/src/runtime/relax_vm/pooled_allocator.h b/src/runtime/relax_vm/pooled_allocator.h new file mode 100644 index 0000000000..0dd7d8b027 --- /dev/null +++ b/src/runtime/relax_vm/pooled_allocator.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/runtime/relax_vm/pooled_allocator.h + */ +#ifndef TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ +#define TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +class PooledAllocator final : public Allocator { + public: + static constexpr size_t kDefaultPageSize = 4096; + + explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize) + : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {} + + ~PooledAllocator() { ReleaseAll(); } + + Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { + std::lock_guard lock(mu_); + size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; + auto&& it = memory_pool_.find(size); + if (it != memory_pool_.end() && !it->second.empty()) { + auto&& pool = it->second; + auto ret = pool.back(); + pool.pop_back(); + return ret; + } + Buffer buf; + buf.device = device_; + buf.size = size; + try { + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = + runtime::DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); + } + + used_memory_.fetch_add(size, std::memory_order_relaxed); + DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + std::lock_guard lock(mu_); + if (memory_pool_.find(buffer.size) == memory_pool_.end()) { + memory_pool_.emplace(buffer.size, std::vector{}); + } + memory_pool_.at(buffer.size).push_back(buffer); + DLOG(INFO) << "reclaim buffer " << buffer.size; + } + + private: + void ReleaseAll() { + std::lock_guard lock(mu_); + for (auto const& it : memory_pool_) { + auto const& pool = it.second; + for (auto const& buf : pool) { + runtime::DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data); + } + } + memory_pool_.clear(); + used_memory_ = 0; + DLOG(INFO) << "release all buffers"; + } + + private: + size_t page_size_; + std::atomic used_memory_; + std::unordered_map > memory_pool_; + std::recursive_mutex mu_; + Device device_; +}; + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_RELAX_VM_POOLED_ALLOCATOR_H_ diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc new file mode 100644 index 0000000000..65656c42c8 --- /dev/null +++ b/src/runtime/relax_vm/vm.cc @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/runtime/relax_vm/vm.cc + */ + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace relax_vm { + +inline TVMRetValue CopyConstantTo(TVMRetValue src, const DLDevice& dev) { + NDArray nd_array = src.operator tvm::runtime::NDArray(); + if (nd_array->device.device_type == dev.device_type && + nd_array->device.device_id == dev.device_id) { + return src; + } + TVMRetValue ret; + ret = nd_array.CopyTo(dev); + return ret; +} + +VMFunction VirtualMachine::LookupVMFunction(const std::string& func_name) { + ICHECK(exec_) << "The executable is not created yet."; + const auto& m = this->exec_->global_map; + if (m.find(func_name) == m.end()) { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } + Index gf_idx = m.at(func_name); + const VMFunction& vm_func = exec_->global_funcs[gf_idx]; + return vm_func; +} + +RegType VirtualMachine::LookupVMOutput(const std::string& func_name) { + if (!outputs_.count(func_name)) { + LOG(FATAL) << "ValueError: No output saved for call of \"" << func_name + << "\"; use `invoke_stateful` to call it first."; + } + return outputs_[func_name]; +} + +// Use the args after `starting_arg_idx` as a series of indices into `obj`, +// indexing into nested ADTs and returning the final indexed object. +ObjectRef IndexIntoNestedObject(ObjectRef obj, TVMArgs args, int starting_arg_idx) { + for (int i = starting_arg_idx; i < args.size(); i++) { + // the object must be an ADT to be able to index into it + if (!obj.as()) { + LOG(FATAL) << "ValueError: Attempted to index into an object that is not an ADT."; + } + int index = args[i]; + auto adt = Downcast(obj); + // make sure the index is in bounds + if (index >= static_cast(adt.size())) { + LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << adt.size() << ")."; + } + obj = adt[index]; + } + return obj; +} + +PackedFunc VirtualMachine::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "vm_initialization") { + // initialize the VirtualMachine, takes variable-length arguments + // first argument is a runtime::Module, followed by one or more device_type, device_id, + // and the AllocatorType associated with the device. 
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK_EQ(args.size() % 3, 0); + std::vector devices; + std::vector alloc_types; + for (int i = 0; i < args.size(); i += 3) { + Device dev; + int device_type = args[i]; + dev.device_type = DLDeviceType(device_type); + dev.device_id = args[i + 1]; + int type = args[i + 2]; + devices.push_back(dev); + alloc_types.push_back(AllocatorType(type)); + } + this->Init(devices, alloc_types); + + // Copy NDArray constants to the devices + // TODO(tvm-team): support multiple devices + this->constants.reserve(exec_->constants.size()); + for (const auto& constant : exec_->constants) { + if (constant.type_code() != kTVMNDArrayHandle) { + this->constants.push_back(constant); + } else { + this->constants.push_back(CopyConstantTo(constant, devices[0])); + } + } + }); + } else if (name == "save_function") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + std::string closure_name = args[1]; + bool include_return = args[2]; + const auto& m = exec_->global_map; + if (m.find(func_name) == m.end()) { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } + if (m.find(closure_name) != m.end()) { + LOG(FATAL) << "ValueError: Name " << closure_name << " is already taken."; + } + Index gf_idx = m.at(func_name); + std::vector inputs; + if (args.size() > 3) { + inputs = std::vector(args.size() - 3); + for (int i = 3; i < args.size(); i++) { + SetInputTensorWithIndex(inputs, args[i], i - 3, devices[0]); + } + } + if (include_return) { + saved_closures_[closure_name] = + PackedFunc([this, gf_idx, inputs](TVMArgs args, TVMRetValue* rv) { + *rv = this->Invoke(gf_idx, inputs); + }); + } else { + saved_closures_[closure_name] = + PackedFunc([this, gf_idx, inputs](TVMArgs args, TVMRetValue* rv) { + this->Invoke(gf_idx, inputs); + }); + } + }); + } else if (name == "invoke_closure") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + ICHECK(exec_) << "The executable is not created yet."; + VMClosure clo = args[0]; + Array func_args = args[1]; + std::vector new_args; + for (auto f_arg : func_args) { + TVMRetValue arg; + arg = f_arg; + new_args.push_back(arg); + } + // Append the free variables of closure + auto free_vars = clo->free_vars; + for (auto f_var : free_vars) { + TVMRetValue arg; + arg = f_var; + new_args.push_back(arg); + } + + String func_name = clo->func_name; + auto it = exec_->global_map.find(func_name); + ICHECK(it != exec_->global_map.end()) << "No such function " << func_name; + Index func_idx = it->second; + *rv = Invoke(func_idx, new_args); + }); + } else if (name == "invoke_stateful") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + const auto& m = this->exec_->global_map; + if (m.find(func_name) == m.end()) { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } + Index gf_idx = m.at(func_name); + if (!inputs_.count(func_name)) { + LOG(FATAL) << "ValueError: No inputs set for stateful call of " << func_name + << "; use `set_input` first."; + return; + } + outputs_[func_name] = this->Invoke(gf_idx, inputs_[func_name]); + }); + } else if (name == "get_output_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + RegType out = LookupVMOutput(func_name); + // use remaining args as indices + ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef(), args, 1); + // after chasing through the indices, 
examine the final object + if (const auto* adt = obj.as()) { + *rv = static_cast(adt->size); + } else { + *rv = -1; + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + RegType out = LookupVMOutput(func_name); + // use remaining args as indices + ObjectRef obj = IndexIntoNestedObject(out.AsObjectRef(), args, 1); + if (obj.as()) { + LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. " + "Please specify another index argument."; + return; + } + *rv = obj; + }); + } else if (name == "set_input") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); }); + } else if (name == "get_function_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + const VMFunction& vm_func = LookupVMFunction(func_name); + *rv = static_cast(vm_func.param_names.size()); + }); + } else if (name == "get_function_param_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + int index = args[1]; + const VMFunction& vm_func = LookupVMFunction(func_name); + if (static_cast(index) >= vm_func.param_names.size()) { + LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << index << " out of " + << vm_func.param_names.size() << ")"; + } + *rv = vm_func.param_names[index]; + }); + } + + // check if this is a function we saved + if (saved_closures_.count(name)) { + return saved_closures_[name]; + } + + const auto& m = exec_->global_map; + if (m.find(name) != m.end()) { + Index gf_idx = m.at(name); + return PackedFunc([sptr_to_self, this, gf_idx, name](TVMArgs args, TVMRetValue* rv) { + if (inputs_.count(name)) { + LOG(FATAL) << "ValueError: If inputs have been set, `invoke_stateful`" + << " must be used to invoke a function!"; + return; + } else { + std::vector inputs(args.size()); + for (int i = 0; i < args.size(); ++i) { + inputs[i] = args[i]; + } + *rv = this->Invoke(gf_idx, inputs); + } + }); + } else { + LOG(FATAL) << "ValueError: Unknown function: " << name; + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); + } +} + +void VirtualMachine::LoadExecutable(ObjectPtr exec) { + this->exec_ = exec; + CHECK_LE(exec_->imports().size(), 1); + this->lib = exec_->imports().empty() ? Optional(NullOpt) : exec_->imports()[0]; +} + +RegType VirtualMachine::Invoke(Index gf_idx, const std::vector& args) { + const VMFunction& gfunc = exec_->global_funcs[gf_idx]; + // Get the curr instr which might be a potential caller. + Instruction curr_instr = exec_->GetInstruction(pc_); + PushFrame(this->pc_, gfunc); + // Get new frame and set the caller info. 
+ VMFrame* curr_frame = frames_.back().get(); + if (curr_instr.op == Opcode::Call) { + curr_frame->caller_return_register = curr_instr.dst; + } + + // load arguments to the register file + ICHECK_EQ(static_cast(gfunc.num_args), args.size()) + << "ValueError: Invoking function " << gfunc.name << " requires " << gfunc.num_args + << " inputs but only " << args.size() << " inputs are provided."; + for (size_t i = 0; i < args.size(); ++i) { + WriteRegister(frames_.back().get(), i, args[i]); + } + // set program counter + pc_ = gfunc.start_instr; + RunLoop(); + return return_value_; +} + +void VirtualMachine::Init(const std::vector& devices, + const std::vector& alloc_types) { + // TODO(@yuchen): support multi-device heterogeneous execution + ICHECK_LT(devices.size(), 3) + << "Currently relax vm only supports at most 2 devices (host + device)"; + ICHECK_EQ(devices.size(), alloc_types.size()); + + this->devices.reserve(devices.size()); + this->allocators.reserve(alloc_types.size()); + for (size_t i = 0; i < devices.size(); i++) { + auto alloc = MemoryManager::GetOrCreateAllocator(devices[i], alloc_types[i]); + this->devices.push_back(devices[i]); + this->allocators.push_back(alloc); + } +} + +void VirtualMachine::PrepareFuncTable(Index func_index) { + // fast path, function already in cache; + + if (static_cast(func_table_.size()) > func_index && func_table_[func_index] != nullptr) + return; + + if (static_cast(func_table_.size()) <= func_index) { + func_table_.resize(func_index + 1, nullptr); + } + + const std::string& func_name = exec_->func_names[func_index]; + + // lookup function and populate + PackedFunc func{nullptr}; + if (this->lib.defined()) { + func = this->lib.value()->GetFunction(func_name, true); + } + if (!func.defined()) { + const PackedFunc* p_func = Registry::Get(func_name); + if (p_func == nullptr) { + const auto& m = exec_->global_map; + ICHECK(m.find(func_name) != m.end()) + << "Error: Cannot find function " << func_name + << " in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in " + "global Relax functions of the VM executable"; + func = this->GetFunction(func_name, GetObjectPtr(this)); + } else { + func = *(p_func); + } + } + func_table_[func_index] = func; +} + +void VirtualMachine::RunInstrCall(VMFrame* curr_frame, Instruction instr) { + DLOG(INFO) << "\n pc = " << pc_ << ", execute: " << exec_->func_names[instr.func_idx]; + + // Use the call arg stack from the current frame to increase reuse + // and avoid re-allocation + curr_frame->call_arg_values.resize(instr.num_args); + curr_frame->call_arg_tcodes.resize(instr.num_args); + + // NOTE: no changes and resize to those vector ref(otherwise can leads to segfault) + // in the remainder part of the function. 
+ std::vector& values = curr_frame->call_arg_values; + std::vector& tcodes = curr_frame->call_arg_tcodes; + + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + for (Index i = 0; i < instr.num_args; ++i) { + Instruction::Arg arg = instr.args[i]; + switch (arg.kind()) { + case Instruction::kRegister: { + if (arg.value() == Instruction::kVMRegister) { + setter(i, this); + } else { + setter(i, ReadRegister(curr_frame, arg.value())); + } + break; + } + case Instruction::kImmediate: { + setter(i, arg.value()); + break; + } + case Instruction::kConstIdx: { + setter(i, this->constants[arg.value()]); + break; + } + default: { + LOG(FATAL) << "ValueError: Unknown argument kind: " << int(arg.kind()); + } + } + } + TVMArgs args(values.data(), tcodes.data(), values.size()); + TVMRetValue ret; + // prepare and invoke + this->PrepareFuncTable(instr.func_idx); + func_table_[instr.func_idx].CallPacked(args, &ret); + + // save the return value to the register + if (instr.dst != Instruction::kVoidArg) { + WriteRegister(curr_frame, instr.dst, ret); + } + // increment pc + pc_++; +} + +int64_t VirtualMachine::LoadScalarInt(RegName reg) const { + int64_t result = 0; + VMFrame* curr_frame = frames_.back().get(); + const RegType& obj = ReadRegister(curr_frame, reg); + NDArray ndarray = obj.operator tvm::runtime::NDArray(); + NDArray ndarray_host = ndarray.CopyTo(devices[0]); + + switch (ndarray_host->dtype.bits) { + case 1: { + result = reinterpret_cast(ndarray_host->data)[0]; + break; + } + case 8: { + result = reinterpret_cast(ndarray_host->data)[0]; + break; + } + case 16: { + result = reinterpret_cast(ndarray_host->data)[0]; + break; + } + case 32: { + result = reinterpret_cast(ndarray_host->data)[0]; + break; + } + case 64: { + result = reinterpret_cast(ndarray_host->data)[0]; + break; + } + default: + LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(ndarray_host->dtype); + } + return result; +} + +void VirtualMachine::RunLoop() { + VMFrame* curr_frame = frames_.back().get(); + + while (true) { + ICHECK_LT(static_cast(pc_), exec_->instr_offset.size()) << "run into invalide section"; + Instruction instr = exec_->GetInstruction(pc_); + switch (instr.op) { + case Opcode::Call: { + this->RunInstrCall(curr_frame, instr); + break; + } + case Opcode::Ret: { + // If we have hit the point from which we started + // running, we should return to the caller breaking + // the dispatch loop. + return_value_ = ReadRegister(curr_frame, instr.result); + RegName caller_return_register = curr_frame->caller_return_register; + PopFrame(); + if (frames_.size() == 0) { + // directly return if no frame in the call stack. + } else { + // return from a local call. + // Update the current frame to be the parent frame. 
+ curr_frame = frames_.back().get(); + WriteRegister(curr_frame, caller_return_register, return_value_); + } + return; + } + case Opcode::Goto: { + pc_ += instr.pc_offset; + break; + } + case Opcode::If: { + int64_t cond_val = LoadScalarInt(instr.cond); + if (cond_val != 0) { + pc_++; + } else { + ICHECK_GT(instr.false_offset, 1); + pc_ += instr.false_offset; + } + break; + } + } + } +} + +void VirtualMachine::PushFrame(Index ret_pc, const VMFunction& vm_func) { + frames_.emplace_back(std::make_unique(ret_pc, vm_func.register_file_size)); +} + +void VirtualMachine::PopFrame() { + ICHECK_GT(frames_.size(), 0); + pc_ = frames_.back()->return_pc; + frames_.pop_back(); +} + +inline void VirtualMachine::WriteRegister(VMFrame* frame, Index r, const RegType& val) { + frame->register_file[r] = val; +} + +inline RegType VirtualMachine::ReadRegister(VMFrame* frame, Index r) const { + return frame->register_file[r]; +} + +void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { + const auto& m = exec_->global_map; + if (m.find(func_name) != m.end()) { + Index gf_idx = m.at(func_name); + const VMFunction& vm_func = exec_->global_funcs[gf_idx]; + size_t params_num = vm_func.num_args; + ICHECK_EQ(args.size() - offset, params_num) + << "The number of provided parameters doesn't match the number of arguments for"; + std::vector func_args(params_num); + for (int i = offset; i < args.size(); ++i) { + int index = i - offset; + SetInputTensorWithIndex(func_args, args[i], index, devices[0]); + } + inputs_.emplace(func_name, func_args); + } else { + LOG(FATAL) << "ValueError: Unknown function: " << func_name; + } +} + +inline ObjectRef CopyTo(ObjectRef src, const DLDevice& dev) { + if (src->IsInstance()) { + auto nd_array = Downcast(src); + if (nd_array->device.device_type != dev.device_type || + nd_array->device.device_id != dev.device_id) { + VLOG(2) << "copying from " << nd_array->device.device_type << "[" + << nd_array->device.device_id << "] to " << dev.device_type << "[" << dev.device_id + << "]"; + return nd_array.CopyTo(dev); + } + return src; + } else { + ICHECK(src->IsInstance()) + << "VM data must be NDArray or a list of NDArray, but received: " << src->_type_key; + std::vector ret; + ADT adt = Downcast(src); + for (size_t i = 0; i < adt.size(); i++) { + ret.push_back(CopyTo(adt[i], dev)); + } + return ADT(adt->tag, ret.begin(), ret.end()); + } +} + +void VirtualMachine::SetInputTensorWithIndex(std::vector& func_args, + const TVMArgValue& inp_tensor, int index, Device dev) { + if (inp_tensor.type_code() == kTVMDLTensorHandle) { + if (NDArray::AbilityOfZeroCopyForDLTensor(inp_tensor, dev)) { + func_args[index] = NDArray::FromExternalDLTensor(*inp_tensor); + } else { + func_args[index] = NDArray::NewFromDLTensor(inp_tensor, dev); + } + } else { + func_args[index] = CopyTo(inp_tensor, dev); + } +} + +} // namespace relax_vm +} // namespace runtime +} // namespace tvm diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922d..addf129284 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. 
Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f..de8a7a3b09 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,34 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + frame->global_var_map.Set(func_name, gv); + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->functions.Set(gv, func); +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 0000000000..58d5e53f70 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h
new file mode 100644
index 0000000000..58d5e53f70
--- /dev/null
+++ b/src/script/ir_builder/ir/utils.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
+#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
+
+#include
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace ir {
+
+inline IRModuleFrame FindModuleFrame(const String& method) {
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>()) {
+    const Optional<IRModuleFrame>& last_module_frame = builder->GetLastFrame<IRModuleFrame>();
+    if (last_module_frame.defined() && last_module_frame.value() == frame) {
+      return frame.value();
+    }
+  } else {
+    LOG(FATAL) << "ValueError: IRModule frame not found. Please ensure '" << method
+               << "' is called under I.ir_module()";
+  }
+  LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()";
+  throw;
+}
+
+}  // namespace ir
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
+
+#endif  // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_
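`FindModuleFrame` also enforces that declaration and definition happen immediately under `I.ir_module()`. A hypothetical illustration from Python (the binding spelling `decl_function` is an assumption):

from tvm.script.ir_builder import ir as I

# With no enclosing I.ir_module() frame, the builder is expected to fail
# with "IRModule frame not found".
try:
    I.decl_function("f")  # assumed Python spelling of I.DeclFunction
except Exception as err:
    print(err)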
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
new file mode 100644
index 0000000000..8a7c2ff538
--- /dev/null
+++ b/src/script/ir_builder/relax/frame.cc
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace relax {
+
+void SeqExprFrameNode::ExitWithScope() {
+  // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call
+  // its `ExitWithScope` and check that no unended BlockFrame remains.
+  if (Optional<BlockFrame> block_frame = IRBuilder::Current()->FindFrame<BlockFrame>()) {
+    block_frame.value()->ExitWithScope();
+    ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>().defined())
+        << "ValueError: There is some remaining BlockFrame that is not properly popped out.";
+  }
+  RelaxFrameNode::ExitWithScope();
+}
+
+void FunctionFrameNode::ExitWithScope() {
+  using ir::IRModuleFrame;
+  using tvm::relax::Expr;
+  SeqExprFrameNode::ExitWithScope();
+  IRBuilder builder = IRBuilder::Current();
+  // Step 1: Create the function.
+  CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use "
+                             "`return` to return an Expr";
+  output = this->block_builder->Normalize(output.value());
+  Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value()));
+  tvm::relax::Function func(/*params=*/params,
+                            /*body=*/body,
+                            /*ret_type=*/ret_type.value_or(Type()),
+                            /*ret_shape=*/tvm::relax::RuntimeDepShape(),
+                            /*attrs=*/DictAttrs(attrs));
+  // TODO(relax-team): remove this line
+  func = WithAttr(func, "global_symbol", name.value());
+  // Step 2: Update IRModule.
+  if (builder->frames.empty()) {
+    // Case 0. No outer frame: return the function directly.
+    ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
+    builder->result = func;
+  } else if (Optional<IRModuleFrame> opt_frame = builder->FindFrame<IRModuleFrame>()) {
+    // Case 1. A global function of an IRModule.
+    CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the "
+                             "function scope, if it's defined in a Module";
+    const IRModuleFrame& frame = opt_frame.value();
+    const String& func_name = name.value_or("");
+    if (!frame->global_var_map.count(func_name)) {
+      // First time visiting the function.
+      ir::DeclFunction(func_name);
+    }
+    // Define the function.
+    // Note we do checks to disallow redefinition of functions inside `DefFunction`.
+    ir::DefFunction(func_name, func);
+  } else {
+    LOG(FATAL) << "ValueError: Cannot find where to insert Relax.Function";
+  }
+}
+
+void BlockFrameNode::EnterWithScope() {
+  // Step 1. If the last frame is a block frame, the start of a new block frame marks the end of
+  // the last one.
+  Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  if (block_frame.defined()) {
+    block_frame.value()->ExitWithScope();
+    // Block frames cannot appear consecutively.
+    ICHECK(!IRBuilder::Current()->GetLastFrame<BlockFrame>().defined());
+  }
+  // Step 2. Deal with the new block frame.
+  RelaxFrameNode::EnterWithScope();
+  Optional<FunctionFrame> func_frame = IRBuilder::Current()->FindFrame<FunctionFrame>();
+  CHECK(func_frame.defined())
+      << "ValueError: Cannot find FunctionFrame when creating BindingBlocks. Please ensure "
+         "the block is created under a Relax function scope.";
+  const tvm::relax::BlockBuilder& block_builder = func_frame.value()->block_builder;
+  if (is_dataflow) {
+    block_builder->BeginDataflowBlock();
+  } else {
+    block_builder->BeginBindingBlock();
+  }
+}
+
+void BlockFrameNode::ExitWithScope() {
+  // Step 1. Pop the current frame out of the frame stack.
+  RelaxFrameNode::ExitWithScope();
+
+  // Step 2. Get the constructed binding block from the block builder. The block should have at
+  // least one binding - otherwise, the block is not supposed to be created.
+  const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
+  tvm::relax::BindingBlock block = block_builder->EndBlock();
+  CHECK(!block->bindings.empty())
+      << "ValueError: A binding block should have at least one binding.";
+
+  // Step 3. Get the last frame from the IRBuilder frame stack.
+  Optional<RelaxFrame> opt_last_frame = IRBuilder::Current()->GetLastFrame<RelaxFrame>();
+  ICHECK(opt_last_frame.defined());
+  RelaxFrame last_frame = opt_last_frame.value();
+
+  // Step 4. Since we popped out any possible block frame when entering the "with" scope of the
+  // current frame, the last frame cannot be a block frame.
+  ICHECK(!last_frame->IsInstance<BlockFrameNode>());
+
+  // Step 5. Push the block into the corresponding field of the last frame.
+  if (const auto* seq_frame = last_frame.as<SeqExprFrameNode>()) {
+    ICHECK(!seq_frame->output.defined())
+        << "The function is not expected to have output values when emitting blocks.";
+    auto frame = GetRef<SeqExprFrame>(seq_frame);
+    frame->binding_blocks.push_back(block);
+  } else {
+    LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame "
+                  "or a block frame. However, the last frame is \""
+               << last_frame->GetTypeKey() << "\".";
+    // TODO(ruihang): support IfFrame and then IfFrame is a possible branch here.
+  }
+}
+
+void IfFrameNode::EnterWithScope() {
+  const Array<IRBuilderFrame>& frames = IRBuilder::Current()->frames;
+  for (const IRBuilderFrame& frame : frames) {
+    const auto* block_frame = frame.as<BlockFrameNode>();
+    if (block_frame && block_frame->is_dataflow) {
+      LOG(FATAL) << "ValueError: Cannot create an IfFrame inside a dataflow block.";
+    }
+  }
+  RelaxFrameNode::EnterWithScope();
+}
+
+void IfFrameNode::ExitWithScope() {
+  RelaxFrameNode::ExitWithScope();
+  CHECK(then_expr.defined())
+      << "ValueError: The body of the then branch is expected to be defined before exiting.";
+  CHECK(else_expr.defined())
+      << "ValueError: The body of the else branch is expected to be defined before exiting.";
+  auto body = tvm::relax::If(condition, then_expr.value(), else_expr.value());
+  var = Emit(body, /*is_dataflow_var=*/false);
+  IRBuilder::Name(var_name, var);
+}
+
+void ThenFrameNode::EnterWithScope() {
+  IfFrame frame = FindIfFrame("R.Then");
+  CHECK(!frame->then_expr.defined())
+      << "ValueError: Duplicate then branch declaration, previous one is "
+      << frame->then_expr.value();
+  SeqExprFrameNode::EnterWithScope();
+}
+
+void ThenFrameNode::ExitWithScope() {
+  SeqExprFrameNode::ExitWithScope();
+  String var_name;
+  output = GetSeqExprForBranch(GetRef<SeqExprFrame>(this), &var_name);
+  IfFrame frame = FindIfFrame("R.Then");
+  frame->then_expr = output;
+  frame->var_name = var_name;
+}
+
+void ElseFrameNode::EnterWithScope() {
+  IfFrame frame = FindIfFrame("R.Else");
+  CHECK(frame->then_expr.defined()) << "The else branch should follow the then branch";
+  CHECK(!frame->else_expr.defined())
+      << "ValueError: Duplicate else branch declaration, previous one is "
+      << frame->else_expr.value();
+  SeqExprFrameNode::EnterWithScope();
+}
+
+void ElseFrameNode::ExitWithScope() {
+  SeqExprFrameNode::ExitWithScope();
+  String var_name;
+  output = GetSeqExprForBranch(GetRef<SeqExprFrame>(this), &var_name);
+  IfFrame frame = FindIfFrame("R.Else");
+  frame->else_expr = output;
+  CHECK(frame->var_name == var_name)
+      << "The last binding of both branches must have the same variable.";
+}
+
+TVM_REGISTER_NODE_TYPE(FunctionFrameNode);
+TVM_REGISTER_NODE_TYPE(SeqExprFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(IfFrameNode);
+TVM_REGISTER_NODE_TYPE(ThenFrameNode);
+TVM_REGISTER_NODE_TYPE(ElseFrameNode);
+
+}  // namespace relax
+}  // namespace ir_builder
+}  // namespace script
+}  // namespace tvm
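Together these frames implement the `@R.function` surface syntax that the Python tests later in this diff rely on. A small sketch of how the pieces line up (shapes and dtypes are illustrative):

from tvm.script import relax as R

@R.function
def main(x: Tensor((32, 32), "float32")) -> Tensor:
    # The body runs inside a FunctionFrame plus an implicit SeqExprFrame.
    with R.dataflow():       # BlockFrame with is_dataflow=True
        lv0 = R.add(x, x)    # Emit() into the current block
        R.output(lv0)        # DataflowBlockOutput ends the block
    return lv0               # FuncRetValue records the function output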
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
new file mode 100644
index 0000000000..a23ab1e736
--- /dev/null
+++ b/src/script/ir_builder/relax/ir.cc
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace ir_builder {
+namespace relax {
+
+///////////////////////////////// Vars //////////////////////////////////
+
+using tvm::script::ir_builder::details::Namer;
+
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::relax::VarNode>([](const ObjectRef& node, String name) -> void {
+      using tvm::relax::VarNode;
+      const VarNode* var = node.as<VarNode>();
+      relay::IdNode* vid = const_cast<relay::IdNode*>(var->vid.get());
+      vid->name_hint = name;
+    });
+
+TVM_STATIC_IR_FUNCTOR(Namer, vtable)
+    .set_dispatch<tvm::relax::DataflowVarNode>([](const ObjectRef& node, String name) -> void {
+      using tvm::relax::DataflowVarNode;
+      const DataflowVarNode* var = node.as<DataflowVarNode>();
+      relay::IdNode* vid = const_cast<relay::IdNode*>(var->vid.get());
+      vid->name_hint = name;
+    });
+
+////////////////////////////// Tensor Type //////////////////////////////
+
+TensorType::TensorType(tvm::relax::DynTensorType type, Optional<tvm::relax::Expr> shape) {
+  auto n = make_object<TensorTypeNode>();
+  n->type = std::move(type);
+  n->shape = std::move(shape);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorTypeNode);
+
+TensorType Tensor(Optional<Array<PrimExpr>> shape, DataType dtype, int ndim) {
+  using namespace tvm::relax;
+  if (shape.defined() && ndim >= 0) {
+    CHECK_EQ(shape.value().size(), ndim)
+        << "The dimension of the given shape is mismatched with the given `ndim`";
+  } else if (shape.defined()) {
+    ndim = shape.value().size();
+  }
+  Optional<Expr> shape_expr = NullOpt;
+  if (shape.defined()) {
+    shape_expr = ShapeExpr(shape.value());
+  }
+  return TensorType(DynTensorType(ndim, dtype), shape_expr);
+}
+
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.Tensor").set_body_typed(Tensor);
+
+/////////////////////////////// Function ////////////////////////////////
+
+FunctionFrame Function() {
+  ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
+  n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/NullOpt);
+  return FunctionFrame(n);
+}
+
+tvm::relax::Var Arg(const String& name, const Type& type, const tvm::relax::ShapeExpr& shape) {
+  FunctionFrame frame = FindFunctionFrame("R.Arg");
+  tvm::relax::Var var(name, shape, type);
+  frame->params.push_back(var);
+  return var;
+}
+
+void FuncName(const String& name) {
+  FunctionFrame frame = FindFunctionFrame("R.func_name");
+  if (frame->name.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value()
+               << "\"";
+  }
+  frame->name = name;
+}
+
+void FuncAttrs(Map<String, ObjectRef> attrs) {
+  FunctionFrame frame = FindFunctionFrame("R.func_attr");
+  if (!frame->attrs.empty()) {
+    LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs;
+  }
+  frame->attrs = attrs;
+}
+
+void FuncRetType(tvm::Type ret_type) {
+  FunctionFrame frame = FindFunctionFrame("R.ret_type");
+  if (frame->ret_type.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate function return type, previous one is:\n "
+               << frame->ret_type.value();
+  }
+  frame->ret_type = ret_type;
+}
+
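The `Tensor` helper above is what annotations like `Tensor((32, 32), "float32")` in the later tests lower to: a `DynTensorType` paired with an optional shape expression. For orientation, the equivalent construction through the Relax Python API:

from tvm import relax, tir

m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
ty = relax.DynTensorType(ndim=2, dtype="float32")
shape = relax.ShapeExpr([m, n])
x = relax.Var("x", shape, ty)  # the same (shape, type) pairing the frame records for R.Arg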
+void FuncRetValue(const tvm::relax::Expr& value) {
+  // Step 1. The current Relax TVMScript syntax only allows a function return to appear at the end
+  // of the function body. Therefore, if there is any unended block frame when dealing with the
+  // function return, we should end the block frame.
+  Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  if (block_frame.defined()) {
+    block_frame.value()->ExitWithScope();
+    ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>().defined())
+        << "All block frames are supposed to have been popped out already";
+  }
+  // Step 2. Add the output value to the function frame.
+  FunctionFrame frame = FindFunctionFrame("return");
+  CHECK(!frame->output.defined())
+      << "ValueError: Relax functions don't support multiple return statements. Please make sure "
+         "the return statement appears at the end of the function.";
+  frame->output = value;
+}
+
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetType").set_body_typed(FuncRetType);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue);
+
+///////////////////////////// BindingBlock //////////////////////////////
+
+BlockFrame Dataflow() {
+  ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
+  n->is_dataflow = true;
+  n->block_ended = false;
+  return BlockFrame(n);
+}
+
+BlockFrame BindingBlock() {
+  ObjectPtr<BlockFrameNode> n = make_object<BlockFrameNode>();
+  n->is_dataflow = false;
+  n->block_ended = false;
+  return BlockFrame(n);
+}
+
+void DataflowBlockOutput(const Array<tvm::relax::Var>& vars) {
+  // Step 1. Check that we're in a dataflow block that has not ended.
+  Optional<BlockFrame> block_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>();
+  CHECK(block_frame.defined() && block_frame.value()->is_dataflow)
+      << "ValueError: `R.output` should appear inside a dataflow block. However, the current "
+         "innermost block is not a dataflow block.";
+  CHECK(!block_frame.value()->block_ended)
+      << "ValueError: It is not allowed for a dataflow block to have multiple output operations.";
+
+  // Step 2. Mark the block frame as ended, so that any followup binding after this mark in the
+  // dataflow block will lead to an error.
+  block_frame.value()->block_ended = true;
+
+  // Step 3. All the output variables must be global variables and must be emitted by this
+  // dataflow block.
+  Array<tvm::relax::Var> emitted_vars = block_frame.value()->emitted_vars;
+  for (const tvm::relax::Var& var : vars) {
+    CHECK(!var->IsInstance<tvm::relax::DataflowVarNode>())
+        << "ValueError: The output variables of a dataflow block must all be global variables.";
+    CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end())
+        << "ValueError: An output variable is not emitted by this dataflow block. Please make "
+           "sure all dataflow block output variables are emitted exactly by this block.";
+  }
+
+  // Step 4. All normal variables emitted by this dataflow block should be output variables.
+  for (const tvm::relax::Var& emitted_var : emitted_vars) {
+    if (!emitted_var->IsInstance<tvm::relax::DataflowVarNode>()) {
+      CHECK(std::find(vars.begin(), vars.end(), emitted_var) != vars.end())
+          << "ValueError: A non-dataflow variable of this dataflow block is not an output "
+             "variable. Please make sure all non-dataflow variables emitted by this block are "
+             "contained in the output variable list.";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput")
+    .set_body_typed(DataflowBlockOutput);
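`DataflowBlockOutput` encodes the dataflow-block contract: every variable that escapes the block must be listed in `R.output`, and nothing may be emitted after it. In surface syntax:

from tvm.script import relax as R

@R.function
def f(x: Tensor((4, 4), "float32")) -> Tensor:
    with R.dataflow():
        lv = R.add(x, x)   # DataflowVar: stays local, need not be output
        gv = R.add(lv, x)  # escapes the block, so it must appear in R.output
        R.output(gv)       # a second R.output would trip the block_ended check
    return gv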
+
+/////////////////////////////// Bindings ///////////////////////////////
+
+tvm::relax::Var Emit(const tvm::relax::Expr& expr, bool is_dataflow_var) {
+  BlockFrame block_frame = CheckBlockFrameExistAndUnended();
+  const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
+  tvm::relax::Var var{nullptr};
+  if (block_frame->is_dataflow && !is_dataflow_var) {
+    var = block_builder->EmitOutput(expr);
+  } else {
+    var = block_builder->Emit(expr);
+  }
+  block_frame->emitted_vars.push_back(var);
+  return var;
+}
+
+Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value,   //
+                                         const Array<PrimExpr>& pattern,  //
+                                         bool emit_var,                   //
+                                         bool is_dataflow_var) {
+  BlockFrame block_frame = CheckBlockFrameExistAndUnended();
+  tvm::relax::BlockBuilder block_builder = GetBlockBuilder();
+
+  // If we don't intend to emit a variable, just emit the binding and return.
+  if (!emit_var) {
+    tvm::relax::MatchShape match_shape(value, pattern, tvm::relax::Var{nullptr});
+    block_builder->EmitMatchShape(match_shape);
+    return NullOpt;
+  }
+
+  // TODO(tvm-team): Enhance the API of EmitMatchShape in BlockBuilder and then update the
+  // following code snippet.
+  tvm::relax::Var var{nullptr};
+  tvm::relax::Id vid(is_dataflow_var ? "lv" : "gv");
+
+  if (is_dataflow_var) {
+    var = tvm::relax::DataflowVar(vid, NullOpt, NullOpt);
+  } else {
+    var = tvm::relax::Var(vid, NullOpt, NullOpt);
+  }
+
+  if (value->checked_type().as<tvm::relax::ShapeTypeNode>()) {
+    UpdateType(var, tvm::relax::ShapeType());
+  } else if (const tvm::relax::DynTensorTypeNode* tty =
+                 value->checked_type().as<tvm::relax::DynTensorTypeNode>()) {
+    tvm::relax::ShapeExpr shape = tvm::relax::ShapeExpr(pattern);
+    UpdateShape(var, shape);
+    DataType dtype = tty->dtype;
+    UpdateType(var, tvm::relax::DynTensorType(pattern.size(), dtype));
+  } else {
+    LOG(FATAL) << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType.";
+  }
+
+  block_frame->emitted_vars.push_back(var);
+  return block_builder->EmitMatchShape(tvm::relax::MatchShape(value, pattern, var));
+}
+
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit);
+TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(EmitMatchShape);
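`EmitMatchShape` backs `R.match_shape`, which binds a value's runtime shape against a pattern and optionally binds the result to a fresh var (cf. `VarExample` in the analysis tests below):

from tvm.script import relax as R

@R.function
def g(p: Tensor) -> Tensor:
    R.match_shape(p, (5, 5))      # emit_var=False: records only the shape constraint
    r = R.match_shape(p, (5, 5))  # emit_var=True: also binds a var, typed from p
    return r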
" + << "The Type is expected to be " << var_type << " but got annotation: " << anno_type; + } + + if (var->shape_.defined() && anno_shape.defined()) { + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::Expr var_shape = Downcast(var->shape_.value()); + CHECK(block_builder->CanProveShapeEqual(var_shape, anno_shape.value())) + << " The shape of var " << var->name_hint() << " is expected to be " << var_shape + << " but got annotation: " << anno_shape.value(); + } + + var->checked_type_ = anno_type; + var->shape_ = anno_shape; +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.AnnotateTypeShape").set_body_typed(AnnotateTypeShape); + +///////////////////////////// If Then Else ///////////////////////////// + +IfFrame If(tvm::relax::Expr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_expr = NullOpt; + n->else_expr = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h new file mode 100644 index 0000000000..e55957cdbf --- /dev/null +++ b/src/script/ir_builder/relax/utils.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ + +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +inline FunctionFrame FindFunctionFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method + << "' is called under R.function()"; + throw; +} + +inline IfFrame FindIfFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under R.if_()"; + } + throw; +} + +inline tvm::relax::BlockBuilder GetBlockBuilder() { + Optional frame = IRBuilder::Current()->FindFrame(); + CHECK(frame.defined()) << "ValueError: Relax Function frame not find. 
Please ensure " + "assignment is called under R.function()"; + return frame.value()->block_builder; +} + +inline BlockFrame CheckBlockFrameExistAndUnended() { + // - If we're emitting a non-dataflow binding in the function (that is to say, the binding is not + // wrapped by `with R.dataflow()`), it is possible that there is no existing BlockFrame. In this + // case, we will create a BlockFrame and "enter its 'with' scope" first. + // - Otherwise, there is already an existing BlockFrame. We check if the block is "ended" - if a + // block is ended, it is not allowed to emit new bindings into this block, and we should throw + // exceptions. + + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + CHECK(!block_frame.value()->block_ended) + << "ValueError: New binding is not allowed after dataflow block output."; + return block_frame.value(); + } + + BlockFrame new_block_frame = BindingBlock(); + new_block_frame->EnterWithScope(); + return new_block_frame; +} + +inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) { + // Step 0. Check frame type + std::string method; + if (frame->IsInstance()) { + method = "R.Then"; + } else if (frame->IsInstance()) { + method = "R.Else"; + } else { + ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey(); + } + + // Step 1. Check non-empty block and last binding is non-dataflow + CHECK(!frame->binding_blocks.empty()) + << "Empty body is not allowed for '" << method << "' statements."; + const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); + CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; + + // Step 2. Collect body from the last binding. + tvm::relax::Expr body; + const tvm::relax::Binding& last_binding = last_block->bindings.back(); + if (const auto* var_binding = last_binding.as()) { + CHECK(!var_binding->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = var_binding->var->name_hint(); + } else if (const auto* match_shape = last_binding.as()) { + CHECK(match_shape->var.defined() && + !match_shape->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = match_shape->var->name_hint(); + } else { + ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); + } + + // Step 3. Re-collect binding blocks to remove the last binding. + Array new_blocks(frame->binding_blocks.begin(), + frame->binding_blocks.end() - 1); + Array last_block_bindings(last_block->bindings.begin(), + last_block->bindings.end() - 1); + new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings)); + + return tvm::relax::SeqExpr(new_blocks, body); +} + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index aa9efa653f..315b3edd46 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. 
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index aa9efa653f..315b3edd46 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include
 #include
 #include
 
@@ -42,9 +43,17 @@ void PrimFuncFrameNode::ExitWithScope() {
     ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
     builder->result = func;
   } else if (Optional<ir::IRModuleFrame> opt_frame = builder->FindFrame<ir::IRModuleFrame>()) {
-    ir::IRModuleFrame frame = opt_frame.value();
-    frame->global_vars.push_back(GlobalVar(name.value_or("")));
-    frame->functions.push_back(func);
+    CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the "
+                             "function scope, if it's defined in a Module";
+    const ir::IRModuleFrame& frame = opt_frame.value();
+    const String& func_name = name.value_or("");
+    if (!frame->global_var_map.count(func_name)) {
+      // Case. First time visiting the function.
+      ir::DeclFunction(func_name);
+    }
+    // Define the function.
+    // Note we do checks to disallow redefinition of functions inside `DefFunction`.
+    ir::DefFunction(func_name, func);
   } else {
     LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
   }
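With both TIR and Relax function frames routed through `DeclFunction`/`DefFunction`, functions in one `ir_module` can reference each other irrespective of textual order, since every name is declared before any body is defined. An illustrative sketch modeled on the `VarExample` module in the analysis tests below:

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor) -> Tensor:
        # `func` resolves even though its definition follows: the module
        # frame declares GlobalVars before defining bodies.
        return func(x)

    @R.function
    def func(a: Tensor) -> Tensor:
        return R.add(a, a)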
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index c222de81f2..ca43a93295 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -159,7 +159,6 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
       Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global");
       info->tensor2buffers[tensor] = buffer;
       buffers.push_back(buffer);
-
       if (!info->IsArg(tensor)) {
         info->root_alloc.push_back(info->tensor2buffers[tensor]);
       }
@@ -287,6 +286,12 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
                          /*annotations=*/std::move(annotations)));
 }
 
+inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
+  return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
+         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
+}
+
 Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info,
                              arith::Analyzer* analyzer) {
   // Step 1. Creating loop vars for block bindings.
@@ -457,7 +462,8 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info,
 
 PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
-                                     const Array<Stmt>& root_stmts, CreateFuncInfo* info) {
+                                     const Array<Stmt>& root_stmts, CreateFuncInfo* info,
+                                     const Optional<Array<tir::Var>> tir_var_list) {
   Array<tir::Var> parameters;
   Map<tir::Var, Buffer> buffer_map;
   for (const te::Tensor& tensor : arg_list) {
@@ -467,11 +473,20 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
     ICHECK(it != info->tensor2buffers.end());
     buffer_map.Set(arg, it->second);
   }
+
+  // Add additional arguments for TIR vars that are left unbound by match buffer.
+  if (tir_var_list) {
+    for (const Var& v : tir_var_list.value()) {
+      parameters.push_back(v);
+    }
+  }
+
   PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters),
                                      /*body=*/SeqStmt::Flatten(root_stmts),
                                      /*ret_type=*/VoidType(),
                                      /*buffer_map=*/std::move(buffer_map)),
                             {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}});
+
   const auto* complete = runtime::Registry::Get("script.Complete");
   ICHECK(complete);
   func = (*complete)(std::move(func), info->root_alloc);
@@ -479,7 +494,8 @@ PrimFunc GenerateAndCompletePrimFunc(const Array<te::Tensor>& arg_list,
 }
 
 PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
-                                     const Array<runtime::NDArray>& constants) {
+                                     const Array<runtime::NDArray>& constants,
+                                     const Optional<Array<tir::Var>>& tir_var_list) {
   // Information used in CreatePrimFunc and its sub-functions.
   CreateFuncInfo info(arg_list);
   // Root body stmts.
@@ -497,15 +513,14 @@ PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
   for (const te::Operation& op : order) {
     RewriteStageToBlock(op, &info, &root_stmts, &analyzer);
   }
-
-  // Step 4. Create func and complete prim func.
-  auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
+  auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info, tir_var_list);
   func = tir::BindParams(func, constants);
   return LayoutFreePlaceholdersNormalizer().Process(std::move(func));
 }
 
-PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
-  return CreatePrimFuncWithConstants(arg_list, {});
+PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
+                        const Optional<Array<tir::Var>> tir_var_list) {
+  return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list);
 }
 
 TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);
diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h
index b68d30a2fb..4f72824ccc 100644
--- a/src/te/operation/create_primfunc.h
+++ b/src/te/operation/create_primfunc.h
@@ -28,7 +28,8 @@ namespace tvm {
 namespace tir {
 
 /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
-PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);
+PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list,
+                        const Optional<Array<tir::Var>> tir_var_list);
 
 /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the
  *  constants array is N, the last N tensors in arg_list will be treated as constant tensors.
@@ -36,7 +37,8 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);
  *  will be embedded in the body as AllocateConstNode.
  */
 PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
-                                     const Array<runtime::NDArray>& constants);
+                                     const Array<runtime::NDArray>& constants,
+                                     const Optional<Array<tir::Var>>& tir_var_list);
 
 }  // namespace tir
 }  // namespace tvm
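On the Python side this machinery is reachable through `te.create_prim_func`; the new trailing argument lets shape variables that no buffer binds become explicit `PrimFunc` parameters. A minimal sketch of the common case, where no extra vars are needed because `n` is recoverable from `A`'s buffer:

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")

func = te.create_prim_func([A, B])  # n is bound via A's match-buffer shape
print(func.script())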
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index 4c59a17673..781a0ecd7c 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -115,8 +115,8 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
 
 Pass CreatePrimFuncPass(
     const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
-    int opt_level, String name, tvm::Array<String> required) {
-  PassInfo pass_info = PassInfo(opt_level, name, required);
+    int opt_level, String name, tvm::Array<String> required, bool traceable) {
+  PassInfo pass_info = PassInfo(opt_level, name, required, traceable);
   return PrimFuncPass(pass_func, pass_info);
 }
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index 2680589457..336c45df2c 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -615,7 +615,10 @@ BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
                    [&srch](const auto& c) { return srch.end() == srch.find(c); });
     buf->conflicts.Assign(conflicts.begin(), conflicts.end());
   }
-  return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size);
+
+  Map<BufferInfo, ObjectRef> buffer_info_map_ref =
+      Map<BufferInfo, ObjectRef>(this->buffer_info_map_.begin(), this->buffer_info_map_.end());
+  return BufferInfoAnalysis(buffer_info_map_ref, max_open_set_size);
 }
 
 BufferInfoAnalysis ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index 3350ecc5d4..9ca2515bd1 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -77,7 +77,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ",\n  conflicts=" << node->conflicts.size() << ")";
   });
 
-BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, Stmt> buffer_info_stmts,
+BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, ObjectRef> buffer_info_stmts,
                                        Integer memory_pressure) {
   auto bufinfo_analysis_node = make_object<BufferInfoAnalysisNode>();
   bufinfo_analysis_node->buffer_info_stmts = buffer_info_stmts;
@@ -87,7 +87,8 @@ BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, ObjectRef> buffer_info_st
 TVM_REGISTER_NODE_TYPE(BufferInfoAnalysisNode);
 
 TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoAnalysis")
-    .set_body_typed([](Map<BufferInfo, Stmt> buffer_info_stmts, Integer memory_pressure) {
+    .set_body_typed([](Map<BufferInfo, ObjectRef> buffer_info_stmts,
+                       Integer memory_pressure) {
       return BufferInfoAnalysis(buffer_info_stmts, memory_pressure);
     });
@@ -145,7 +146,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ")";
   });
 
-Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) {
+Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, ObjectRef>& buffer_info_map) {
   Array<BufferInfo> ret;
   for (const auto& kv : buffer_info_map) {
     auto buffer_info = kv.first;
@@ -155,12 +156,12 @@ Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, ObjectRef>& buffer
 }
 
 Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
-    const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
+    const Map<BufferInfo, ObjectRef>& buffer_info_to_stmt,
     const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) {
   Map<Stmt, PoolAllocation> ret;
   for (const auto& kv : buffer_info_to_pool_allocation) {
     BufferInfo bi = kv.first;
-    Stmt stmt_ = buffer_info_to_stmt[bi];
+    Stmt stmt_ = runtime::Downcast<Stmt>(buffer_info_to_stmt[bi]);
     PoolAllocation pa = kv.second;
     ret.Set(stmt_, pa);
   }
@@ -265,7 +266,7 @@ Integer CalculateModuleWorkspaceSize(const IRModule& mod) {
 }
 
 TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo")
-    .set_body_typed([](Map<BufferInfo, Stmt> buffer_info_map) {
+    .set_body_typed([](Map<BufferInfo, ObjectRef> buffer_info_map) {
       return (ConvertToArrayOfBufferInfo(buffer_info_map));
     });
diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh
index 6c01f0eb0a..5123c41e62 100755
--- a/tests/lint/cpplint.sh
+++ b/tests/lint/cpplint.sh
@@ -18,12 +18,5 @@
 
 set -e
 
-echo "Running 2 cpplints (VTA and TVM)..."
-python3 3rdparty/dmlc-core/scripts/lint.py --quiet vta cpp vta/include vta/src
-python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp \
-    include src \
-    examples/extension/src examples/graph_executor/src \
-    tests/cpp tests/crt \
-    --exclude_path "src/runtime/hexagon/rpc/hexagon_rpc.h" \
-    "src/runtime/hexagon/rpc/hexagon_rpc_skel.c" \
-    "src/runtime/hexagon/rpc/hexagon_rpc_stub.c"
+python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp \
+    include/tvm/relax src/relax/
diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py b/tests/python/contrib/test_hexagon/test_relax_integration.py
new file mode 100644
index 0000000000..1947cc7eae
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_relax_integration.py
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+import tvm.testing
+from tvm import relay, relax, runtime
+from tvm.relax.testing import relay_translator
+from tvm.contrib.hexagon.session import Session
+from tvm.script import relax as R, tir as T
+from tvm.relay import testing
+
+
+@tvm.testing.requires_hexagon
+def test_conv2d(hexagon_session: Session):
+    dtype = "float32"
+    data = relay.var("data", relay.TensorType((1, 64, 64, 3), dtype))
+    weight = relay.var("weight", relay.TensorType((5, 5, 3, 8), dtype))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    relay_mod = tvm.IRModule.from_expr(f)
+
+    # target_hexagon = "llvm -keys=hexagon -link-params=0 -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv69 -mtriple=hexagon"
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target)
+
+    ex = relax.vm.build(relax_mod, target)
+    dev = hexagon_session.device
+    vm_mod = hexagon_session.get_executor_from_factory(ex)
+    vm_rt = relax.VirtualMachine(vm_mod, dev)
+
+    data_np = np.random.rand(1, 64, 64, 3).astype(np.float32)
+    weight_np = np.random.rand(5, 5, 3, 8).astype(np.float32)
+
+    # Run on hexagon and get result
+    data = tvm.nd.array(data_np, dev)
+    weight = tvm.nd.array(weight_np, dev)
+    vm_rt.set_input("main", data, weight)
+    vm_rt.invoke_stateful("main")
+    hexagon_res = vm_rt.get_outputs("main")
+
+    # Compile and run on Relay for comparison.
+    dev = tvm.cpu()
+    data = tvm.nd.array(data_np, dev)
+    weight = tvm.nd.array(weight_np, dev)
+
+    target = tvm.target.Target("llvm", host="llvm")
+    vm_exec = relay.vm.compile(relay_mod, target=target)
+    vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+    relay_res = vm_factory.invoke("main", data, weight)
+    tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3)
+
+
+@tvm.testing.requires_hexagon
+def test_conv2d_dyn(hexagon_session: Session):
+    dtype = "float32"
+    data = relay.var("data", relay.TensorType((relay.Any(), 64, 64, 3), dtype))
+    weight = relay.var("weight", relay.TensorType((5, 5, 3, 8), dtype))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    relay_mod = tvm.IRModule.from_expr(f)
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target)
+
+    ex = relax.vm.build(relax_mod, target)
+    hexagon_device = hexagon_session.device
+    vm_mod = hexagon_session.get_executor_from_factory(ex)
+    vm_rt = relax.VirtualMachine(vm_mod, hexagon_device)
+
+    data_np = np.random.rand(1, 64, 64, 3).astype(np.float32)
+    weight_np = np.random.rand(5, 5, 3, 8).astype(np.float32)
+
+    # Run on hexagon and get result
+    data = tvm.nd.array(data_np, hexagon_device)
+    weight = tvm.nd.array(weight_np, hexagon_device)
+    vm_rt.set_input("main", data, weight)
+    vm_rt.invoke_stateful("main")
+    hexagon_res = vm_rt.get_outputs("main")
+
+    # Compile and run on Relay for comparison.
+    cpu_device = tvm.cpu()
+    data = tvm.nd.array(data_np, cpu_device)
+    weight = tvm.nd.array(weight_np, cpu_device)
+
+    target = tvm.target.Target("llvm", host="llvm")
+    vm_exec = relay.vm.compile(relay_mod, target=target)
+    vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+    relay_res = vm_factory.invoke("main", data, weight)
+    tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3)
+
+
+@tvm.testing.requires_hexagon
+def test_mlp(hexagon_session: Session):
+    relay_mod, params = testing.mlp.get_workload(batch_size=1, dtype="float32")
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target, params)
+
+    ex = relax.vm.build(relax_mod, target)
+    hexagon_device = hexagon_session.device
+
+    vm_mod = hexagon_session.get_executor_from_factory(ex)
+    vm_rt = relax.VirtualMachine(vm_mod, hexagon_device)
+
+    shape = (1, 1, 28, 28)
+    data_np = np.random.rand(*shape).astype("float32")
+    data = tvm.nd.array(data_np, hexagon_device)
+    vm_rt.set_input("main", data)
+    vm_rt.invoke_stateful("main")
+    hexagon_res = vm_rt.get_outputs("main")
+
+    # Compile and run on Relay for comparison.
+    cpu_dev = tvm.cpu()
+    data = tvm.nd.array(data_np, cpu_dev)
+
+    target = tvm.target.Target("llvm", host="llvm")
+    vm_exec = relay.vm.compile(relay_mod, target=target)
+    vm_factory = runtime.vm.VirtualMachine(vm_exec, cpu_dev)
+    relay_res = vm_factory.invoke("main", data, **params)
+    tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3)
+
+
+@tvm.testing.requires_hexagon
+def test_mlp_dyn(hexagon_session: Session):
+    relay_mod, params = testing.mlp.get_workload(batch_size=relay.Any(), dtype="float32")
+    shape = (1, 1, 28, 28)
+    data_np = np.random.rand(*shape).astype("float32")
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target = tvm.target.Target(target_hexagon, host=target_hexagon)
+
+    # translate the relay mlp workload and bind params
+    relax_mod = relay_translator.from_relay(relay_mod["main"], target, params)
+
+    # Compile and run on Hexagon.
+    ex = relax.vm.build(relax_mod, target)
+    dev = hexagon_session.device
+
+    vm_mod = hexagon_session.get_executor_from_factory(ex)
+    vm_rt = relax.VirtualMachine(vm_mod, dev)
+    data = tvm.nd.array(data_np, dev)
+    vm_rt.set_input("main", data)
+    vm_rt.invoke_stateful("main")
+    hexagon_res = vm_rt.get_outputs("main")
+
+    # Compile and run on Relay for comparison.
+ dev = tvm.cpu() + data = tvm.nd.array(data_np, dev) + + target = tvm.target.Target("llvm", host="llvm") + vm_exec = relay.vm.compile(relay_mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, dev) + relay_res = vm_factory.invoke("main", data, **params) + print(hexagon_res) + print(relay_res) + tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3) + + +def get_onnx_mobilenet(): + """Download and import mobilenet model with ONNX""" + import onnx # pylint: disable=import-outside-toplevel + + model_url = "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx" + model_path = tvm.contrib.download.download_testdata( + model_url, "mobilenetv2-7.onnx", module="onnx" + ) + return onnx.load(model_path) + + +@pytest.mark.skip("takes too long (~20min)") +@tvm.testing.requires_hexagon +def test_mobilenet_onnx(hexagon_session: Session): + onnx_model = get_onnx_mobilenet() + data_np = np.random.rand(1, 3, 224, 224).astype("float32") + shape_dict = {"input": data_np.shape} + relay_mod, _ = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True) + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + relax_mod = relay_translator.from_relay(relay_mod["main"], target_hexagon) + + # Compile and run on Hexagon. + ex = relax.vm.build(relax_mod, target) + dev = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(ex) + vm_rt = relax.VirtualMachine(vm_mod, dev) + data = tvm.nd.array(data_np, dev) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on LLVM for comparison. + relax_mod = relay_translator.from_relay(relay_mod["main"], "llvm") + ex = relax.vm.build(relax_mod, "llvm") + dev = tvm.cpu() + vm_rt = relax.VirtualMachine(ex, dev) + data = tvm.nd.array(data_np, dev) + llvm_res = vm_rt["main"](data) + tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) + + +@pytest.mark.skip("takes too long (~20min)") +@tvm.testing.requires_hexagon +def test_mobilenet(hexagon_session: Session): + relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") + data_np = np.random.rand(1, 3, 224, 224).astype("float32") + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + # translate the relay mobilenet and bind params + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + + # Compile and run on Hexagon. + ex = relax.vm.build(relax_mod, target) + dev = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(ex) + vm_rt = relax.VirtualMachine(vm_mod, dev) + data = tvm.nd.array(data_np, dev) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on LLVM for comparison. 
+ relax_mod = relay_translator.from_relay(relay_mod["main"], "llvm", params) + ex = relax.vm.build(relax_mod, "llvm") + dev = tvm.cpu() + vm_rt = relax.VirtualMachine(ex, dev) + data = tvm.nd.array(data_np, dev) + llvm_res = vm_rt["main"](data) + tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) + + +@pytest.mark.skip("takes too long (~20min)") +@tvm.testing.requires_hexagon +def test_mobilenet_dyn(hexagon_session: Session): + relay_mod, params = testing.mobilenet.get_workload(batch_size=relay.Any(), dtype="float32") + data_np = np.random.rand(1, 3, 224, 224).astype("float32") + + target_hexagon = tvm.target.hexagon("v68") + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + # translate the relay mobilenet and bind params + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + + # Compile and run on Hexagon. + ex = relax.vm.build(relax_mod, target) + dev = hexagon_session.device + + vm_mod = hexagon_session.get_executor_from_factory(ex) + vm_rt = relax.VirtualMachine(vm_mod, dev) + data = tvm.nd.array(data_np, dev) + vm_rt.set_input("main", data) + vm_rt.invoke_stateful("main") + hexagon_res = vm_rt.get_outputs("main") + + # Compile and run on Relay for comparison. + dev = tvm.cpu() + data = tvm.nd.array(data_np, dev) + + target = tvm.target.Target("llvm", host="llvm") + vm_exec = relay.vm.compile(relay_mod, target=target) + vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu()) + relay_res = vm_factory.invoke("main", data, **params) + tvm.testing.assert_allclose(hexagon_res.numpy(), relay_res.numpy(), rtol=1e-3) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/integration/test_relax_rpc_tuning.py b/tests/python/integration/test_relax_rpc_tuning.py new file mode 100644 index 0000000000..45594d3b63 --- /dev/null +++ b/tests/python/integration/test_relax_rpc_tuning.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test tuning a model in Relax over RPC, end-to-end.""" +from __future__ import annotations +import os +import subprocess +import time +from typing import Callable, Any + +import tvm +from tvm import rpc +from tvm.rpc.tracker import Tracker +from tvm.contrib import utils +import tvm.testing + + +def retry_with_backoff(thunk: Callable[[], Any]) -> Any: + """ + Calls the thunk and, if it fails (raises an exception), tries again after a 1s backoff. + """ + try: + return thunk() + except: # pylint: disable=bare-except + time.sleep(1.0) + return thunk() + + +def check_connection(host: str, port: int, key: str) -> bool: + """ + Returns true if the tracker at host:port has any servers under the given key + """ + # Timeout is set to 5 because the script tries again every 5s if it fails; + # we will only permit it one try. 
+ # (We could use `rpc.connect_tracker` directly but it retries indefinitely.) + check = subprocess.check_output( + [ + "python3", + "-m", + "tvm.exec.query_rpc_tracker", + "--host", + host, + "--port", + str(port), + ], + timeout=5, + ) + # if the key isn't in the printed message, then they didn't connect + return key in str(check) + + +def connect_server(host: str, port: int, key: str) -> rpc.Server: + """ + Starts a server and attempts to connect it to a tracker + at the given host and port with the given key. + + Subsequently checks if the connection succeeded. + """ + server = rpc.Server(host=host, key=key, tracker_addr=(host, port)) + # retry in case we check before the connection comes in + if not retry_with_backoff(lambda: check_connection(host, port, key)): + raise Exception("Failed to connect") + return server + + +@tvm.testing.slow +def test_relax_auto_tir_e2e_rpc(): + """ + Run the e2e_auto_tir Relax example script over RPC on localhost. + """ + rpc_host = "127.0.0.1" + rpc_key = "Test1" + rpc_port = 5555 + + # if we don't bind tracker and server to variables, they are deleted and closed + tracker = Tracker(host=rpc_host, port=rpc_port) # pylint: disable=unused-variable + # retry in case the server tries to connect before the tracker starts + server = retry_with_backoff( # pylint: disable=unused-variable + lambda: connect_server(rpc_host, rpc_port, rpc_key) + ) + + tuning_dir = utils.tempdir() + run_script = subprocess.run( + [ + "python3", + os.path.join(os.environ["TVM_HOME"], "apps", "relax_examples", "e2e_auto_tir.py"), + "--workload", + "resnet_50", + "--target", + # metascheduler requires specifying the number of cores; + # this uses 16 because that is what is used in the other tuning tests + "llvm -num-cores 16", + "--input-shape", + "[1, 3, 224, 224]", + # 0 trials so there is no tuning, just testing + "--num-trials", + "0", + "--rpc-host", + rpc_host, + "--rpc-port", + str(rpc_port), + "--rpc-key", + rpc_key, + "--work-dir", + tuning_dir.path, + # this can take several minutes and the default timeout is seldom enough + "--rpc-timeout-sec", + "600", + ], + check=False, + capture_output=True, + ) + # just checking that it completes successfully + assert run_script.returncode == 0, (run_script.stdout, run_script.stderr) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py new file mode 100644 index 0000000000..f1b1187066 --- /dev/null +++ b/tests/python/relax/conftest.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm.relax.ir.instrument import WellFormedInstrument
+
+
+tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()])
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
new file mode 100644
index 0000000000..125dbdcf66
--- /dev/null
+++ b/tests/python/relax/test_analysis.py
@@ -0,0 +1,413 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+from typing import List, Set, Union
+import pytest
+
+import tvm
+from tvm import tir
+from tvm import relax as rx
+from tvm.relax.analysis import (
+    udchain,
+    remove_all_unused,
+    name_to_binding,
+    shape_vars,
+    derive_func_ret_shape,
+    all_vars,
+    free_vars,
+    bound_vars,
+    all_global_vars,
+    called_global_vars,
+)
+from tvm.script import relax as R
+
+
+def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]:
+    return set(map(lambda v: v.name_hint, vars))
+
+
+def test_dispatch_var():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+    type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+    v0 = rx.Var("v0", [m, n], type_anno0)
+    v1 = rx.DataflowVar("v1", [n], type_anno1)
+    t = None
+
+    def fvisit(e):
+        nonlocal t
+        t = type(e)
+
+    rx.analysis.post_order_visit(v0, fvisit)
+    assert t == type(v0)
+    rx.analysis.post_order_visit(v1, fvisit)
+    assert t == type(v1)
+
+
+def test_post_order_visit():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+    type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+    x = rx.Var("x", [m, n], type_anno0)
+    y = rx.Var("y", [n], type_anno1)
+    ib = rx.BlockBuilder()
+    with ib.function("func", [x, y]):
+        with ib.dataflow() as df:
+            lv0 = ib.emit(rx.op.add(x, y))
+            lv1 = ib.emit(rx.op.multiply(lv0, y))
+            gv0 = ib.emit_output(lv1)
+        ib.emit_func_output(gv0)
+    expr = ib.get()["func"]
+
+    names = []
+
+    def fvisit(e):
+        nonlocal names
+        if isinstance(e, tvm.ir.op.Op):
+            names.append(e.name)
+
+    rx.analysis.post_order_visit(expr.body, fvisit)
+    assert names == ["relax.add", "relax.multiply"]
+
+
+def test_use_def():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+    type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+    x = rx.Var("x", [m, n], type_anno0)
+    y = rx.Var("y", [n], type_anno1)
+    ib = rx.BlockBuilder()
+    with ib.function("func", [x, y]):
+        with ib.dataflow():
+            lv0 = ib.emit(rx.op.add(x, y))
+            lv1 = ib.emit(rx.op.multiply(lv0, y))
+            gv0 = ib.emit_output(lv1)
+        ib.emit_func_output(gv0)
+    dfb = ib.get()["func"].body.blocks[0]
+    udc = udchain(dfb)
+    assert set(udc[x]) ==
{lv0} + assert set(udc[y]) == {lv0, lv1} + assert set(udc[lv0]) == {lv1} + assert set(udc[lv1]) == {gv0} + assert set(udc[gv0]) == set() + + +def test_chained_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir(my_sigmoid, (x,), (32, 32), dtype="float32") + unused1 = R.call_tir(my_sigmoid, (unused0,), (32, 32), dtype="float32") + R.output(lv0) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir(my_sigmoid, (x,), (32, 32), dtype="float32") + unused1 = R.call_tir(my_sigmoid, (unused0,), (32, 32), dtype="float32") + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return z + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return z + + tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + + +def test_binding_block_fake_unused_remove_all_unused(): + @tvm.script.ir_module + class IdentityUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32"))) + return lv0 + + optimized = remove_all_unused(IdentityUnused["main"]) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + # This might bring side effect so cannot be removed. 
+            z = R.call_packed("vm.builtin.copy", lv0, type_args=(Tensor((32, 32), "float32")))
+            return lv0
+
+    tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+
+
+def test_edge_binding_block_fake_unused_remove_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"):
+            z = R.call_packed("vm.builtin.copy", x, type_args=(Tensor((32, 32), "float32")))
+            return x
+
+    optimized = remove_all_unused(IdentityUnused["main"])
+    tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+
+
+def test_name_to_binding_var_shadowing():
+    @R.function
+    def main(x: Tensor((32, 32), "float32")) -> Tensor:
+        with R.dataflow():
+            lv0 = x
+            lv1 = lv0
+            R.output(lv1)
+
+        with R.dataflow():
+            lv0 = lv1  # shadowing
+            lv2 = lv0
+            R.output(lv2)
+        return lv2
+
+    n2binding = name_to_binding(main)
+
+    assert "lv0" in n2binding
+    assert "lv1" in n2binding
+    assert "lv2" in n2binding
+
+    assert len(n2binding["lv0"]) == 2
+
+
+def test_shape_var_shape_expr():
+    v1 = tir.Var("v1", "int64")
+    v2 = tir.Var("v2", "int64")
+    v3 = tir.Var("v3", "int64")
+    shape_expr = rx.ShapeExpr([v1, v2, tir.Add(v3, v1)])
+    vars = shape_vars(shape_expr)
+
+    assert len(vars) == 3
+    assert v1 in vars
+    assert v2 in vars
+    assert v3 in vars
+
+    shape_expr = rx.ShapeExpr([tir.const(1), tir.const(2)])
+    vars = shape_vars(shape_expr)
+    assert len(vars) == 0
+
+
+def test_shape_var_nested():
+    v1 = rx.Var("v1")
+    v2 = rx.Var("v2")
+    sv1 = tir.Var("sv1", "int64")
+    shape_expr = rx.ShapeExpr([sv1])
+    tup = rx.Tuple([v1, v2, shape_expr])
+    vars = shape_vars(tup)
+
+    assert len(vars) == 1
+    assert sv1 in vars
+
+    x = rx.Var("x", type_annotation=rx.DynTensorType(ndim=-1, dtype="int64"))
+    y = rx.Var("y", type_annotation=rx.DynTensorType(ndim=-1, dtype="int64"))
+
+    func = rx.Function([x, y], shape_expr, rx.ShapeType(), rx.RuntimeDepShape())
+    vars = shape_vars(func)
+
+    assert len(vars) == 1
+    assert sv1 in vars
+
+
+def test_derive_func_ret_shape_no_free():
+    sv1 = tir.Var("sv1", "int64")
+    sv2 = tir.Var("sv2", "int64")
+    sv3 = tir.Var("sv3", "int64")
+    a1 = rx.Var(
+        "a1", type_annotation=rx.DynTensorType(ndim=2), shape_annotation=rx.ShapeExpr([sv1, sv2])
+    )
+    a2 = rx.Var(
+        "a2", type_annotation=rx.DynTensorType(ndim=2), shape_annotation=rx.ShapeExpr([sv2, sv3])
+    )
+    body = a2
+    shape_expr = derive_func_ret_shape([a1, a2], body)
+
+    assert isinstance(shape_expr, rx.ShapeExpr)
+    assert shape_expr[0] == sv2
+    assert shape_expr[1] == sv3
+
+
+def test_derive_func_ret_shape_free():
+    sv1 = tir.Var("sv1", "int64")
+    sv2 = tir.Var("sv2", "int64")
+    sv3 = tir.Var("sv3", "int64")
+    a1 = rx.Var(
+        "a1", type_annotation=rx.DynTensorType(ndim=2), shape_annotation=rx.ShapeExpr([sv1, sv2])
+    )
+    a2 = rx.Var(
+        "a2", type_annotation=rx.DynTensorType(ndim=2), shape_annotation=rx.ShapeExpr([sv2, sv1])
+    )
+    # Artificially introducing a free shape variable.
+ # This would not be a valid program, but this is being done to test the logic + body = rx.Var( + "a3", type_annotation=rx.DynTensorType(ndim=2), shape_annotation=rx.ShapeExpr([sv1, sv3]) + ) + shape_expr = derive_func_ret_shape([a1, a2], body) + assert isinstance(shape_expr, rx.RuntimeDepShape) + + +@tvm.script.ir_module +class VarExample: + @R.function + def func(a: Tensor) -> Tensor: + return R.add(a, a) + + @R.function + def main(x: Tensor, y: Tensor) -> Tensor: + z = R.add(x, y) + # no binding here + R.match_shape(x, (5, 5)) + with R.dataflow(): + q = R.add(z, z) + p = func(q) + r = R.match_shape(p, (5, 5)) + s = r + R.output(s) + return s + + +def test_all_vars(): + vars = all_vars(VarExample["func"]) + assert len(vars) == 1 + assert vars[0].name_hint == "a" + + var_names = var_name_set(all_vars(VarExample["main"])) + assert var_names == {"x", "y", "z", "p", "q", "r", "s"} + + +def test_bound_vars(): + vars = bound_vars(VarExample["func"]) + assert len(vars) == 1 + assert vars[0].name_hint == "a" + + # all the vars are bound + var_names = var_name_set(bound_vars(VarExample["main"])) + assert var_names == {"x", "y", "z", "p", "q", "r", "s"} + + # if we consider only the body, then the function arguments are not bound + body_names = var_name_set(bound_vars(VarExample["main"].body)) + assert body_names == {"z", "p", "q", "r", "s"} + + # if the argument isn't bound, then nothing is + assert len(bound_vars(VarExample["func"].body)) == 0 + + +def test_free_vars(): + # all the vars are bound + assert len(free_vars(VarExample["func"])) == 0 + assert len(free_vars(VarExample["main"])) == 0 + + # the arguments are free if we look only at the bodies + func_free = var_name_set(free_vars(VarExample["func"].body)) + main_free = var_name_set(free_vars(VarExample["main"].body)) + assert len(func_free) == 1 + assert len(main_free) == 2 + assert "a" in func_free + assert main_free == {"x", "y"} + + # function that captures vars + x = rx.Var("x", type_annotation=rx.DynTensorType(ndim=-1)) + y = rx.Var("y", type_annotation=rx.DynTensorType(ndim=-1)) + z = rx.Var("z", type_annotation=rx.DynTensorType(ndim=-1)) + inner = rx.Function( + [z], + rx.op.add(x, rx.op.add(y, z)), + ret_type=rx.DynTensorType(ndim=-1), + ret_shape=rx.RuntimeDepShape(), + ) + outer = rx.Function( + [x, y], + rx.Call(inner, [y]), + ret_type=rx.DynTensorType(ndim=-1), + ret_shape=rx.RuntimeDepShape(), + ) + assert len(free_vars(outer)) == 0 + assert var_name_set(free_vars(inner)) == {"x", "y"} + + +def test_all_global_vars(): + # there is one call to "func" + global_vars = all_global_vars(VarExample["main"]) + assert len(global_vars) == 1 + assert global_vars[0].name_hint == "func" + + gv1 = rx.GlobalVar("gv1") + gv2 = rx.GlobalVar("gv2") + gv3 = rx.GlobalVar("gv3") + call = rx.Call(gv1, [gv2, gv3]) + call_var_names = var_name_set(all_global_vars(call)) + assert call_var_names == {"gv1", "gv2", "gv3"} + + +def test_called_global_vars(): + # there is one call to "func" + global_vars = called_global_vars(VarExample["main"]) + assert len(global_vars) == 1 + assert global_vars[0].name_hint == "func" + + gv1 = rx.GlobalVar("gv1") + gv2 = rx.GlobalVar("gv2") + gv3 = rx.GlobalVar("gv3") + call = rx.Call(gv1, [gv2, gv3]) + call_vars = called_global_vars(call) + assert len(call_vars) == 1 + assert call_vars[0].name_hint == "gv1" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py new file mode 100644 index 0000000000..41a652f288 --- /dev/null 
+++ b/tests/python/relax/test_ast_printer.py
@@ -0,0 +1,377 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+import pytest
+import re
+
+import tvm
+from tvm import tir
+from tvm import relax as rx
+from tvm.relax.testing import dump_ast
+from tvm.script import tir as T, relax as R
+
+import numpy as np
+
+
+def strip_whitespace(text: str) -> str:
+    """
+    Remove all whitespace to avoid reasoning about newlines and indents
+    """
+    return re.sub(r"\s", "", text)
+
+
+# test cases are mostly adapted from test_expr, only testing very basic properties
+
+
+def test_var() -> None:
+    v0 = rx.Var("v0")
+    v0_str = dump_ast(v0)
+    assert v0_str == 'Var(name_hint="v0")'
+
+    shape_anno = [54, 96]
+    type_anno = rx.DynTensorType(2, "float32")
+    v1 = rx.Var("v1", shape_anno, type_anno)
+    v1_no_annos = dump_ast(v1, include_shape_annotations=False, include_type_annotations=False)
+    assert v1_no_annos == 'Var(name_hint="v1")'
+    v1_annos = dump_ast(v1)
+    assert v1_annos != v1_no_annos
+    assert "PrimExpr" in v1_annos
+    assert "shape_" in v1_annos
+    assert "_checked_type_" in v1_annos
+
+
+def test_dataflow_var() -> None:
+    v0 = rx.DataflowVar("v0")
+    v0_str = dump_ast(v0)
+    assert v0_str == 'DataflowVar(name_hint="v0")'
+
+    shape_anno = [54, 96]
+    type_anno = rx.DynTensorType(2, "float16")
+    v1 = rx.DataflowVar("v1", shape_anno, type_anno)
+    v1_no_annos = dump_ast(v1, include_shape_annotations=False, include_type_annotations=False)
+    assert v1_no_annos == 'DataflowVar(name_hint="v1")'
+    v1_annos = dump_ast(v1)
+    assert v1_annos != v1_no_annos
+    assert "PrimExpr" in v1_annos
+    assert "shape_" in v1_annos
+    assert "_checked_type_" in v1_annos
+
+
+def test_match_shape() -> None:
+    # match_shape([16, 8], [m, n])
+    m = tir.Var("m", dtype="int32")
+    n = tir.Var("n", dtype="int32")
+    shape = rx.const([16, 8], "int32")
+    var = rx.Var("v0", type_annotation=rx.ShapeType())
+    b0 = rx.MatchShape(shape, [m, n], var)
+    b0_str = dump_ast(b0)
+    assert b0_str.startswith("MatchShape(")
+    assert "Constant" in b0_str
+    assert "PrimExpr(value=`m: int32`)" in b0_str
+    assert "PrimExpr(value=`n: int32`)" in b0_str
+    assert "16" in b0_str
+    assert "8" in b0_str
+    assert b0_str != dump_ast(b0, include_type_annotations=False)
+
+    # var1: Tensor((m, n), "float32") =
+    #   match_shape(var0: Tensor(_, "float32"), [m, n])
+    type_anno0 = rx.DynTensorType(-1, "float32")
+    value = rx.Var("value", type_annotation=type_anno0)
+
+    shape_anno = [m, n]
+    type_anno = rx.DynTensorType(2, "float32")
+    var = rx.Var("v1", shape_anno, type_anno)
+    b1 = rx.MatchShape(value, [m, n], var)
+    b1_str = dump_ast(b1)
+    assert b1_str.startswith("MatchShape(")
+    assert "PrimExpr(value=`m: int32`)" in b1_str
+    assert "PrimExpr(value=`n: int32`)" in b1_str
+    assert b1_str != dump_ast(b1,
include_type_annotations=False, include_shape_annotations=False) + + +def test_var_binding() -> None: + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b0 = rx.VarBinding(v0, val) + b0_str = dump_ast(b0) + assert b0_str.startswith("VarBinding(") + assert 'var=Var(name_hint="v0")' in b0_str + assert "value=" in b0_str + assert "Constant(" in b0_str + + +def test_binding_block() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + block0_str = dump_ast(block0) + assert block0_str.startswith("BindingBlock(") + assert "bindings=" in block0_str + assert "VarBinding(" in block0_str + assert "MatchShape(" in block0_str + assert '"v0"' in block0_str + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + block0_str = dump_ast(block0) + assert block0_str.startswith("DataflowBlock(") + assert "bindings=" in block0_str + assert "VarBinding(" in block0_str + assert "MatchShape(" in block0_str + assert '"v0"' in block0_str + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + seqe_str = dump_ast(seqe) + assert seqe_str.startswith("SeqExpr(") + assert "blocks=" in seqe_str + assert "BindingBlock(" in seqe_str + assert "VarBinding(" in seqe_str + assert "Constant(" in seqe_str + assert 'var=Var(name_hint="foo")' in seqe_str + assert "value=Constant(data=1)" in seqe_str + assert "body=" in seqe_str + + +def test_shape_expr() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + s = rx.ShapeExpr([m, n]) + s_str = dump_ast(s) + assert s_str.startswith("ShapeExpr(") + assert "values=" in s_str + assert "PrimExpr(value=`m: int32`)" in s_str + assert "PrimExpr(value=`n: int32`)" in s_str + + +def test_func(): + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("foo", type_annotation=type_anno) + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + ret_type = rx.DynTensorType(-1, "float32") + ret_shape = rx.RuntimeDepShape() + func = rx.Function([x], seqe, ret_type, ret_shape) + func = func.with_attr("global_symbol", "func") + + func_str = dump_ast(func) + assert func_str.startswith("Function(") + assert "params=" in func_str + assert "body=" in func_str + assert "ret_type=" in func_str + assert "ret_shape=" in func_str + assert "attrs=" in func_str + assert '"global_symbol": "func"' in func_str + assert "SeqExpr(" in func_str + assert "blocks=" in func_str + assert "VarBinding(" in func_str + assert func_str != dump_ast(func, include_type_annotations=False) + + +def test_shape_of(): + v0 = rx.Var("v0") + s0 = v0.shape + s0_str = dump_ast(s0) + assert s0_str.startswith("Call(") + assert 'op=Op(name="relax.shape_of")' in s0_str + assert "args=" in s0_str + assert 'Var(name_hint="v0")' in s0_str + + shape_anno = [96, 54] + v1 = rx.Var("v1", shape_anno) + s1 = v1.shape + s1_str = dump_ast(s1) + assert s1_str.startswith("ShapeExpr("), s1_str + assert "values=" in s1_str + assert 
"PrimExpr(value=`96`)" in s1_str + assert "PrimExpr(value=`54`)" in s1_str + + +def test_shape_expr(): + shape_expr = rx.ShapeExpr([10, 20]) + shape_expr_str = dump_ast(shape_expr) + assert shape_expr_str.startswith("ShapeExpr(") + assert "values" in shape_expr_str + assert "PrimExpr(value=`10`)" in shape_expr_str + assert "PrimExpr(value=`20`)" in shape_expr_str + + +def test_call_packed(): + # test case from test_parser + @R.function + def f( + x: Tensor((32, m), "float32"), + y: Tensor((m, k), "float32"), + r: Tensor(_, "int64"), + ) -> Object: + z: Tensor((32, k), "float32") = nn.matmul(x, y, units=None) + w: Tensor(None, _) = multiply(z, z) + q: Tensor(None, _, ndim=2) = add(w, w) + t = subtract(w, z) + sh: Shape = t.shape + o: Object = relax.call_packed("contrib.tensor_array_stack", x, y, type_args=(Object)) + return o + + # checking that the call_packed call is turned into a call to an extern func + f_str = strip_whitespace( + dump_ast( + f, + include_type_annotations=False, + include_shape_annotations=False, + include_call_attrs=False, + ) + ) + extern_call = strip_whitespace( + """ + Call( + op=ExternFunc(global_symbol="contrib.tensor_array_stack"), + args=[ + Var(name_hint="x"), + Var(name_hint="y") + ], + type_args=[ObjectType()] + ) + """ + ) + assert extern_call in f_str + # check that the op call is there too + op_call = strip_whitespace( + """ + Call( + op=Op(name="nn.matmul"), + args=[ + Var(name_hint="x"), + Var(name_hint="y") + ] + ) + """ + ) + assert op_call in f_str + # the function has an annotated return type + assert "ret_type=ObjectType()" in f_str + + # the op call has attributes so let's check those too + f_str_complete = strip_whitespace(dump_ast(f)) + assert f_str != f_str_complete + attrs_str = strip_whitespace( + """ + attrs={ + "units": None, + "out_dtype": "", + "transpose_a": 0, + "transpose_b": 0 + } + """ + ) + assert attrs_str in f_str_complete + + +def test_call_tir(): + # also from test_parser + @R.function + def foo(x: Tensor((m, n), "float32")): + gv0 = relax.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + return gv0 + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_shape_annotations=False, + include_call_attrs=False, + ) + ) + # call_tir is an op in Relax and it takes an extern func as an argument + call_tir_text = strip_whitespace( + """ + Call( + op=Op(name="relax.call_tir"), + args=[ + ExternFunc(global_symbol="test.op.identity"), + Tuple(fields=[Var(name_hint="x")]), + ShapeExpr(values=[ + PrimExpr(value=`m: int64`), + PrimExpr(value=`n: int64`) + ]) + ], + type_args=[DynTensorType(ndim=2, dtype=float32)] + ) + """ + ) + assert foo_str.startswith('Function(params=[Var(name_hint="x")]') + assert call_tir_text in foo_str + + +def test_operators(): + # the operator attributes need to be registered to work in the printer + + @R.function + def foo(x: Tensor): + return relax.unique(x, sorted=True) + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_shape_annotations=False, + ) + ) + # checking that the attributes are present + assert '"sorted":1' in foo_str + assert '"return_inverse"' in foo_str + assert '"return_counts"' in foo_str + assert '"dim"' in foo_str + + @R.function + def bar(x: Tensor): + return relax.print(x, format="{}") + + bar_str = strip_whitespace( + dump_ast( + bar, + include_type_annotations=False, + include_shape_annotations=False, + ) + ) + print_attrs_str = strip_whitespace('{"format": "{}"}') + assert print_attrs_str in bar_str 
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py
new file mode 100644
index 0000000000..7a061b0bbd
--- /dev/null
+++ b/tests/python/relax/test_autotir_integration.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import tempfile
+import time
+
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+from tvm import meta_schedule as ms
+from tvm import relax, transform
+from tvm.ir.module import IRModule
+from tvm.script import relax as R
+from tvm.script import tir as T
+from tvm.target.target import Target
+
+# Test case with dynamic shape.
+# Tuning with dynamic shape is not supported yet.
+"""
+class InputModule:
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        T.func_attr({"global_symbol": "tir_matmul"})
+        m = T.var("int32")
+        n = T.var("int32")
+        k = T.var("int32")
+        A = T.match_buffer(x, (m,n))
+        B = T.match_buffer(y, (n,k))
+        C = T.match_buffer(z, (m,k))
+
+        for (i0, j0, k0) in T.grid(m,n,k):
+            with T.block():
+                i,j,k = T.axis.remap("SSR", [i0,j0,k0])
+                with T.init():
+                    C[i,j] = 0.0
+                C[i,j] += A[i,k] * B[j,k]
+
+    @T.prim_func
+    def tir_relu(x:T.handle, y:T.handle):
+        T.func_attr({"global_symbol": "tir_relu"})
+        m = T.var("int32")
+        n = T.var("int32")
+        A = T.match_buffer(x, (m,n))
+        B = T.match_buffer(y, (m,n))
+        for (i,j) in T.grid(m,n):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = T.max(A[vi, vj], 0.0)
+
+    @R.function
+    def main(x:Tensor((m,n), "float32"), w:Tensor((n,k), "float32")) -> Tensor:
+        with R.dataflow():
+            sh = relax.call_packed("vm.builtin.shape_of", x)
+            x0 = relax.match_shape(sh, (m, n))
+            sh1 = relax.call_packed("vm.builtin.shape_of", w)
+            x1 = relax.match_shape(sh1, (n, k))
+            lv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32")
+            lv1 = R.call_tir(tir_relu, (lv0), (m, k), dtype="float32")
+            relax.output(lv1)
+        return lv1
+"""
+
+
+@pytest.mark.parametrize("dev", ["cpu"])
+def test_autotir(dev: str):
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            T.func_attr({"global_symbol": "tir_matmul"})
+            k = T.var("int32")
+            A = T.match_buffer(x, (32, 32))
+            B = T.match_buffer(y, (32, 32))
+            C = T.match_buffer(z, (32, 32))
+
+            for (i0, j0, k0) in T.grid(32, 32, 32):
+                with T.block():
+                    i, j, k = T.axis.remap("SSR", [i0, j0, k0])
+                    with T.init():
+                        C[i, j] = 0.0
+                    C[i, j] += A[i, k] * B[j, k]
+
+        @T.prim_func
+        def tir_relu(x: T.handle, y: T.handle):
+            T.func_attr({"global_symbol": "tir_relu"})
+            m = T.var("int32")
+            n = T.var("int32")
+            A = T.match_buffer(x, (32, 32))
+            B = T.match_buffer(y, (32, 32))
+            for (i, j) in T.grid(32, 32):
+                with T.block():
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = T.max(A[vi, vj], 0.0)
+
+        @R.function
+        def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
+            with R.dataflow():
+                lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32")
+                lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype="float32")
+                relax.output(lv1)
+            return lv1
+
+    mod = InputModule
+    assert isinstance(mod, IRModule)
+
+    # `dev` is rebound to a tvm.runtime.Device below, so remember the requested
+    # backend string before overwriting it.
+    dev_kind = dev
+    if dev_kind == "cpu":
+        target = Target("llvm --num-cores=16")
+        dev = tvm.cpu()
+    elif dev_kind == "cuda":
+        target = Target("nvidia/nvidia-t4")
+        dev = tvm.cuda()
+
+    database = ms.database.MemoryDatabase()
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        db = ms.relax_integration.tune_relax(
+            mod=mod,
+            target=target,
+            params=None,
+            num_trials_per_iter=2,
+            max_trials_per_task=4,
+            max_trials_global=4,
+            work_dir=work_dir,
+            database=database,
+        )
+        relax_ex = ms.relax_integration.compile_relax(
+            db,
+            mod=mod,
+            target=target,
+            params=None,
+        )
+
+    # Inputs must exist before either measurement below uses them.
+    data = tvm.nd.array(np.random.rand(32, 32).astype(np.float32), dev)
+    weight = tvm.nd.array(np.random.rand(32, 32).astype(np.float32), dev)
+
+    if dev_kind == "cpu":
+        with transform.PassContext(opt_level=3):
+            ex0 = relax.vm.build(mod, target)
+        vm0 = relax.VirtualMachine(ex0, dev)
+
+        # Measure the performance w/o tuning log
+        tic = time.time()
+        vm0["main"](data, weight)
+        toc = time.time()
+        e0 = toc - tic
+        print(f"w/o tuning: {e0}")
+
+    vm1 = relax.VirtualMachine(relax_ex, dev)
+
+    # Measure the performance w/ tuning log
+    tic = time.time()
+    vm1["main"](data, weight)
+    toc = time.time()
+    e1 = toc - tic
+    print(f"w/ tuning: {e1}")
+
+
+@tvm.testing.requires_gpu
+def test_autotir_gpu():
+    test_autotir("cuda")
+
+
+def test_meta_schedule_extract_tasks():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
+            for i, j in T.grid(128, 128):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + 1.0
+
+        @T.prim_func
+        def add2(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
+            for i, j in T.grid(128, 128):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + 2.0
+
+        # It is intentional that `add3` equals `add1`, in order to test the deduplication
+        # correctness.
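+        # extract_tasks should fold `add3` into the same task as `add1`, which
+        # is why `add1` carries weight 3 (two direct call sites in `main` plus
+        # the deduplicated `add3` call) in expected_weights below.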
+ @T.prim_func + def add3(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + @T.prim_func + def multiply1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("multiply"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + @R.function + def main(x: Tensor((128, 128), "float32")) -> Tensor(_, "float32"): + with R.dataflow(): + lv0 = R.call_tir(add1, (x,), (128, 128), dtype="float32") + lv1 = R.call_tir(multiply1, (lv0,), (128, 128), dtype="float32") + lv2 = R.call_tir(add2, (lv1,), (128, 128), dtype="float32") + lv3 = R.call_tir(multiply1, (lv2,), (128, 128), dtype="float32") + lv4 = R.call_tir(add3, (lv3,), (128, 128), dtype="float32") + gv = R.call_tir(add1, (lv4,), (128, 128), dtype="float32") + relax.output(gv) + return gv + + tasks = ms.relax_integration.extract_tasks(Module, Target("llvm --num-cores=16")) + expected_weights = {"add1": 3, "add2": 1, "multiply1": 2} + assert len(tasks) == len(expected_weights) + for task in tasks: + assert task.task_name in expected_weights + assert expected_weights[task.task_name] == task.weight + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py new file mode 100644 index 0000000000..86959cf0a3 --- /dev/null +++ b/tests/python/relax/test_binding_rewrite.py @@ -0,0 +1,333 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
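+
+# Tests for DataflowBlockRewrite. The rewrites are non-destructive: the
+# original function and block are left intact, and the rewritten versions are
+# retrieved via mutated_root_fn()/mutated_dfb() (see assert_immutability below).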
+
+from __future__ import annotations
+import pytest
+
+import re
+
+import tvm
+from tvm._ffi.base import TVMError
+from tvm.relax.binding_rewrite import DataflowBlockRewrite
+from tvm.relax.analysis import name_to_binding
+from tvm.relax.expr import DataflowVar, Var
+from tvm.script import relax as R
+
+
+@tvm.script.ir_module
+class Identity:
+    @R.function
+    def main(x: Tensor((32, 32), "float32")) -> Tensor:
+        with R.dataflow():
+            lv0 = x
+            R.output(lv0)
+        return lv0
+
+
+def assert_immutability(rwt, original_dfb, original_root_fn):
+    assert rwt.mutated_dfb() != original_dfb
+    assert rwt.mutated_root_fn() != original_root_fn
+    assert rwt.mutated_root_fn().body.blocks[0] != original_dfb
+    assert rwt.mutated_root_fn().body.blocks[0] == rwt.mutated_dfb()
+
+
+def test_null_construct():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    DataflowBlockRewrite(dfb, root_fn)
+
+
+def test_simple_add():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+    # check "tmp" added
+    assert "tmp" in name_to_binding(rwt.mutated_root_fn())
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: Tensor((32, 32), "float32")) -> Tensor:
+            with R.dataflow():
+                lv0 = x
+                tmp: Tensor((32, 32), "float32") = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_simple_auto_add_var():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(root_fn.params[0], is_dfvar=False)
+
+    assert isinstance(rwt.mutated_dfb().bindings[-1].var, Var)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+
+def test_simple_auto_add_dfvar():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(root_fn.params[0], is_dfvar=True)
+
+    assert isinstance(rwt.mutated_dfb().bindings[-1].var, DataflowVar)
+
+    # immutability
+    assert_immutability(rwt, dfb, root_fn)
+
+
+def test_simple_remove_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: Tensor((32, 32), "float32")) -> Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused = lv0
+                R.output(lv0)
+            return lv0
+
+    root_fn = IdentityUnused["main"]
+    dfb = root_fn.body.blocks[0]
+
+    n2binding = name_to_binding(IdentityUnused["main"])
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_unused(n2binding["unused"][0].var)
+
+    assert_immutability(rwt, dfb, root_fn)
+
+    # check "unused" removed
+    assert "unused" not in name_to_binding(rwt.mutated_root_fn())
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: Tensor((32, 32), "float32")) -> Tensor:
+            with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"])
+
+
+def test_remove_unused_undef():
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    with pytest.raises(TVMError):
+        rwt = DataflowBlockRewrite(dfb, root_fn)
+        rwt.remove_unused(Var("whatever"))
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.remove_unused(Var("whatever"), allow_undef=True)
+
+    assert root_fn == rwt.mutated_root_fn()
+
+
+def test_simple_rm_all_unused():
+    @tvm.script.ir_module
+    class IdentityUnused:
+        @R.function
+        def main(x: Tensor((32, 32), "float32")) -> Tensor:
+            with R.dataflow():
+                lv0 = x
+                unused0 = lv0
+                unused1 = lv0
+                R.output(lv0)
+ return lv0 + + root_fn = IdentityUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +@tvm.script.ir_module +class DeadDFBlock: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + with R.dataflow(): + lv0 = x + R.output(lv0) + return x + + +def test_empty_dfb_after_removal(): + root_fn = DeadDFBlock["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_unused(DeadDFBlock["main"].body.blocks[0].bindings[0].var) + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_empty_dfb_after_all_removal(): + dfb = DeadDFBlock["main"].body.blocks[0] + root_fn = DeadDFBlock["main"] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + return x + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_chained_rm_all_unused(): + @tvm.script.ir_module + class IdentityChainedUnused: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_tir(my_sigmoid, (x,), (32, 32), dtype="float32") + unused1 = R.call_tir(my_sigmoid, (unused0,), (32, 32), dtype="float32") + R.output(lv0) + return lv0 + + root_fn = IdentityChainedUnused["main"] + dfb = root_fn.body.blocks[0] + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.remove_all_unused() + + @tvm.script.ir_module + class GroundTruth: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + tvm.ir.assert_structural_equal(rwt.mutated_root_fn(), GroundTruth["main"]) + + +def test_simple_replace_all_uses(): + @tvm.script.ir_module + class Lv0To1: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor((32, 32), "float32"): + # lv0 => lv1 + # / \ + # lv2 lv3 + # \ / + # lv4 + with R.dataflow(): + lv0: Tensor((32, 32), "float32") = R.call_tir( + my_relu, (x,), (32, 32), dtype="float32" + ) + lv1: Tensor((32, 32), "float32") = R.call_tir( + my_sigmoid, (x,), (32, 32), dtype="float32" + ) + lv2: Tensor((32, 32), "float32") = R.call_tir( + my_add, (x, lv0), (32, 32), dtype="float32" + ) + lv3: Tensor((32, 32), "float32") = R.call_tir( + my_mul, (x, lv0), (32, 32), dtype="float32" + ) + lv4: Tensor((32, 32), "float32") = R.call_tir( + my_whatever, (lv2, lv3), (32, 32), dtype="float32" + ) + R.output(lv4) + return lv4 + + root_fn = Lv0To1["main"] + dfb = root_fn.body.blocks[0] + + n2binding = name_to_binding(root_fn) + + rwt = DataflowBlockRewrite(dfb, root_fn) + rwt.replace_all_uses(n2binding["lv0"][0].var, n2binding["lv1"][0].var) + rwt.remove_unused(n2binding["lv0"][0].var) + + assert_immutability(rwt, dfb, root_fn) + + n2binding_after = name_to_binding(rwt.mutated_root_fn()) + assert "lv0" not in n2binding_after + + +def test_simple_module_update(): + @tvm.script.ir_module + class Identity: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + 
with R.dataflow():
+                lv0 = x
+                R.output(lv0)
+            return lv0
+
+    root_fn = Identity["main"]
+    dfb = root_fn.body.blocks[0]
+
+    rwt = DataflowBlockRewrite(dfb, root_fn)
+    rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True)
+
+    new_ir = rwt.mutate_irmodule(Identity)
+
+    # immutability
+    assert new_ir != Identity
+    assert 2 == len(new_ir["main"].body.blocks[0].bindings)
+
+    @tvm.script.ir_module
+    class GroundTruth:
+        @R.function
+        def main(x: Tensor((32, 32), "float32")) -> Tensor:
+            with R.dataflow():
+                lv0 = x
+                tmp: Tensor((32, 32), "float32") = x
+                R.output(lv0)
+            return lv0
+
+    tvm.ir.assert_structural_equal(new_ir, GroundTruth)
diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py
new file mode 100644
index 0000000000..c5301abf9f
--- /dev/null
+++ b/tests/python/relax/test_blockbuilder.py
@@ -0,0 +1,638 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations  # must import to defer parsing of annotations
+import pytest
+import tvm
+from tvm import tir, te
+from tvm import relay
+from tvm import relax as rx
+from tvm.tir.function import PrimFunc
+
+from tvm.ir.base import assert_structural_equal
+from tvm.relax import ExternFunc, ShapeExpr, Tuple
+from tvm import topi
+from tvm.relax.testing import nn
+from tvm.script import relax as R, tir as T
+
+
+@tvm.register_func("test.blockbuilder.nop")
+def nop():
+    pass
+
+
+def test_block_builder():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+    type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+    x = rx.Var("x", [m, n], type_anno0)
+    y = rx.Var("y", [n], type_anno1)
+    bb = rx.BlockBuilder()
+
+    bb._begin_binding_block()
+    gv0 = bb.emit(rx.op.add(x, y))
+    bb._begin_dataflow_block()
+    lv0 = bb.emit(rx.op.multiply(gv0, y))
+    gv1 = bb.emit_output(rx.op.multiply(lv0, lv0))
+    b0 = bb._end_block()
+    bb._begin_dataflow_block()
+    lv1 = bb.emit(rx.op.multiply(gv0, y))
+    gv2 = bb.emit_output(rx.op.multiply(lv1, lv1))
+    b1 = bb._end_block()
+    gv3 = bb.emit(rx.op.add(x, y))
+    b2 = bb._end_block()
+
+    assert isinstance(b0, rx.DataflowBlock)
+    assert isinstance(b1, rx.DataflowBlock)
+    assert not isinstance(b2, rx.DataflowBlock)
+
+
+def test_function_single_block():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+    type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+    x = rx.Var("x", [m, n], type_anno0)
+    y = rx.Var("y", [n], type_anno1)
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", [x, y]):
+        with bb.dataflow() as df:
+            lv0 = bb.emit(rx.op.add(x, y))
+            assert lv0.name_hint == "lv"
+            lv1 = bb.emit(rx.op.multiply(lv0, y))
+            assert lv1.name_hint == "lv1"
+            gv0 = bb.emit_output(lv1)
+            assert gv0.name_hint
== "gv" + bb.emit_func_output(gv0) + + func = bb.get()["func"] + assert func.params[0] == x + assert func.params[1] == y + assert func.body.body == gv0 + assert gv0.shape[0] == m + assert gv0.shape[1] == n + assert gv0.checked_type.ndim == 2 + assert gv0.checked_type.dtype == "float16" + assert len(func.body.blocks) == 1 + assert len(func.body.blocks[0].bindings) == 3 + + +def test_function_multi_blocks(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + with bb.function("func", [x, y]): + with bb.dataflow() as df: + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + assert gv0.name_hint == "gv" + gv1 = bb.emit(rx.op.add(gv0, gv0)) + assert gv1.name_hint == "gv1" + with bb.dataflow() as df: + lv1 = bb.emit(rx.op.add(gv1, gv1)) + assert lv1.name_hint == "lv1" + gv2 = bb.emit_output(gv1) + bb.emit_func_output(gv2) + + func = bb.get()["func"] + assert gv2.shape[0] == m + assert gv2.shape[1] == n + assert gv2.checked_type.ndim == 2 + assert gv2.checked_type.dtype == "float16" + assert func.params[0] == x + assert func.params[1] == y + assert func.attrs["global_symbol"] == "func" + assert func.body.body == gv2 + assert len(func.body.blocks) == 3 + assert len(func.body.blocks[0].bindings) == 2 + assert len(func.body.blocks[1].bindings) == 1 + assert len(func.body.blocks[2].bindings) == 2 + + +def test_multi_functions(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + with bb.function("func1", [x, y]): + with bb.dataflow() as df: + lv0 = bb.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + with bb.function("func2", [x, y]): + with bb.dataflow() as df: + lv0 = bb.emit(rx.op.add(y, x)) + # TODO(@yuchen): enable block builder to reset local var unique name map + assert lv0.name_hint == "lv1" + gv0 = bb.emit_output(lv0) + bb.emit_func_output(gv0) + + mod = bb.get() + func1 = mod["func1"] + assert func1.params[0] == x + assert func1.params[1] == y + assert func1.attrs["global_symbol"] == "func1" + assert len(func1.body.blocks) == 1 + func2 = mod["func2"] + assert func2.params[0] == x + assert func2.params[1] == y + assert func2.attrs["global_symbol"] == "func2" + assert len(func2.body.blocks) == 1 + + +def test_block_builder_input_mod(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def before_main(x: Tensor((m, n), "float32"), w: Tensor((n, k), "float32")) -> Tensor: + gv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32") + return gv0 + + @R.function + def after_main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + gv0 = 
R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32")
+        return gv0
+
+    input_mod = InputModule
+    bb = rx.BlockBuilder(input_mod)
+    var_main = input_mod.get_global_var("before_main")
+    bb.update_func(var_main, after_main)
+
+    context_mod = bb.get()
+    assert len(context_mod.get_global_vars()) == 2
+    var_before_main = context_mod.get_global_var("before_main")
+    assert var_main == var_before_main
+    assert_structural_equal(context_mod[var_before_main], after_main)
+
+
+def test_binary_shape_type_deduction():
+    m = tir.Var("m", "int32")
+    n = tir.Var("n", "int32")
+    k = tir.Var("k", "int32")
+    type_anno0 = rx.DynTensorType(ndim=2, dtype="float16")
+    type_anno1 = rx.DynTensorType(ndim=1, dtype="float16")
+    x = rx.Var("x", [m, 1], type_anno0)
+    y = rx.Var("y", [n], type_anno1)
+    z = rx.Var("z", [5], type_anno1)
+    w = rx.Var("w", [k], type_anno1)
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", [x, y, z, w]):
+        with bb.dataflow() as df:
+            lv0 = bb.emit(rx.op.add(x, y))
+            assert lv0.shape[0] == m
+            assert lv0.shape[1] == n
+            assert isinstance(lv0.checked_type, rx.DynTensorType)
+            assert lv0.checked_type.ndim == 2
+            assert lv0.checked_type.dtype == "float16"
+
+            lv1 = bb.emit(rx.op.multiply(x, z))
+            assert lv1.shape[0] == m
+            assert lv1.shape[1] == 5
+            assert isinstance(lv1.checked_type, rx.DynTensorType)
+            assert lv1.checked_type.ndim == 2
+            assert lv1.checked_type.dtype == "float16"
+
+            lv2 = bb.emit(rx.op.multiply(z, w))
+            assert isinstance(lv2.shape, rx.Call)
+            assert isinstance(lv2.checked_type, rx.DynTensorType)
+            assert lv2.checked_type.ndim == 1
+            assert lv2.checked_type.dtype == "float16"
+
+            lv3 = bb.emit(rx.op.multiply(y, w))
+            assert isinstance(lv3.shape, rx.Call)
+            assert isinstance(lv3.checked_type, rx.DynTensorType)
+            assert lv3.checked_type.ndim == 1
+            assert lv3.checked_type.dtype == "float16"
+            gv0 = bb.emit_output(lv3)
+        bb.emit_func_output(gv0)
+    assert isinstance(gv0.shape, rx.Call)
+    assert isinstance(gv0.checked_type, rx.DynTensorType)
+    assert gv0.checked_type.ndim == 1
+    assert gv0.checked_type.dtype == "float16"
+
+
+def test_emit_match_shape():
+    m = tir.Var("m", dtype="int32")
+    n = tir.Var("n", dtype="int32")
+    type_anno0 = rx.DynTensorType(-1, "float32")
+    x = rx.Var("tensor_value", type_annotation=type_anno0)
+    shape_anno = [16, 8]
+    y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno)
+    bb = rx.BlockBuilder()
+
+    with bb.function("func", [x, y]):
+        with bb.dataflow() as df:
+            # lv0: Tensor((m, n), "float32") =
+            #   match_shape(x: Tensor(_, "float32"), [m, n])
+            lv0 = bb.match_shape(x, [m, n])
+            assert isinstance(lv0, rx.DataflowVar)
+            assert lv0.shape[0] == m
+            assert lv0.shape[1] == n
+            assert lv0.checked_type.ndim == 2
+            assert lv0.checked_type.dtype == "float32"
+
+            # lv1: Shape = match_shape(shape, [m, n])
+            lv1 = bb.match_shape(y, [m, n])
+            assert lv1.checked_type == rx.ShapeType()
+            gv0 = bb.emit_output(lv1)
+
+        bb.emit_func_output(gv0)
+    func = bb.get()["func"]
+    block = func.body.blocks[0]
+    b0, b1 = block.bindings[:2]
+    assert isinstance(b0, rx.MatchShape)
+    assert isinstance(b1, rx.MatchShape)
+
+    assert b0.value == x
+    assert b0.pattern[0] == m
+    assert b0.pattern[1] == n
+    assert b0.var == lv0
+
+    assert b1.value == y
+    assert b1.pattern[0] == m
+    assert b1.pattern[1] == n
+    assert b1.var == lv1
+
+
+def test_emit_match_shape_binding_in_dataflow_block():
+    bb = rx.BlockBuilder()
+
+    x = rx.Var("x", type_annotation=rx.DynTensorType(-1, "float32"))
+    m = tir.Var("m", dtype="int32")
+    gv = rx.Var("gv")
+    match_shape =
rx.MatchShape(x, (m,), gv) + + with bb.function("main", [x]): + with bb.dataflow(): + bb.match_shape_binding(match_shape) + bb.emit_output(gv) + bb.emit_func_output(x) + + func = bb.get()["main"] + block = func.body.blocks[0] + b0 = block.bindings[0] + assert isinstance(b0, rx.MatchShape) + + assert b0.value == x + assert b0.pattern[0] == m + assert b0.var == gv + + +def test_normalize(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + add_call = rx.op.multiply(x, y) + assert isinstance(add_call.shape, rx.Call) + + bb.normalize(add_call) + assert isinstance(add_call.shape, rx.ShapeExpr) + assert add_call.shape[0] == m + assert add_call.shape[1] == n + + +def test_call_te(): + bb = rx.BlockBuilder() + dtype = rx.DynTensorType(ndim=2, dtype="float32") + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = rx.Var("x", [n, m], dtype) + y = rx.Var("y", [n, m], dtype) + z = rx.Var("z", [n, m], dtype) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + with bb.dataflow(): + out = bb.emit_output(bb.call_te(te_func, [x, y], {"C": z}, msg="hello")) + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.attrs["global_symbol"] == "rx_func" + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + +def test_emit_te(): + bb = rx.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", [n, m], type_anno) + y = rx.Var("y", [n, m], type_anno) + z = rx.Var("z", [n, m], type_anno) + + def te_func(args, args_dict, msg): + A, B = args + C = args_dict["C"] + D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j]) + E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j]) + return E + + with bb.function("rx_func", [x, y, z]): + out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello") + bb.emit_func_output(out) + + mod = bb.get() + rx_func = mod["rx_func"] + + def get_tir_func(): + A = te.placeholder((n, m), dtype="float32", name="A") + B = te.placeholder((n, m), dtype="float32", name="B") + C = te.placeholder((n, m), dtype="float32", name="C") + out = te_func((A, B), {"C": C}, "") + return tvm.te.create_prim_func([A, B, C, out]) + + # check TIR structure matches expected + assert_structural_equal(mod["te_func"].body, get_tir_func().body) + + # check Relax function calls TIR function with call_tir call + assert rx_func.params[0] == x + assert rx_func.params[1] == y + assert rx_func.params[2] == z + assert rx_func.attrs["global_symbol"] == "rx_func" + assert rx_func.body.body == out + assert len(rx_func.body.blocks) == 1 + assert len(rx_func.body.blocks[0].bindings) == 1 + + call_node = rx_func.body.blocks[0].bindings[0].value + assert isinstance(call_node, rx.Call) + assert call_node.op == relay.op.get("relax.call_tir") + assert len(call_node.args) == 3 + assert call_node.args[0].name_hint == "te_func" + assert call_node.args[1][0] == x + assert call_node.args[1][1] == y + assert call_node.args[1][2] == z + + +def 
test_emit_te_multiple():
+    bb = rx.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    type_anno = rx.DynTensorType(2, "float32")
+    x = rx.Var("x", [n, m], type_anno)
+    y = rx.Var("y", [n, m], type_anno)
+    z = rx.Var("z", [128, m], type_anno)
+
+    def te_func(A):
+        B = te.compute((128, 128), lambda i, j: A[i, j] + 1)
+        return B
+
+    with bb.function("rx_func", [x, y]):
+        x1 = bb.emit_te(te_func, x)
+        y1 = bb.emit_te(te_func, y)
+        z1 = bb.emit_te(te_func, z)
+        bb.emit_func_output(z1)
+
+    mod = bb.get()
+    rx_func = mod["rx_func"]
+
+    prim_func = []
+    for gv in mod.get_global_vars():
+        if isinstance(mod[gv], PrimFunc):
+            prim_func.append(mod[gv])
+
+    # only two PrimFuncs are generated: two of the three emit_te calls produce
+    # structurally equal TIR, so they are deduplicated into a single te_func
+    assert len(prim_func) == 2
+    assert rx_func.body.blocks[0].bindings[0].value.args[0].name_hint == "te_func"
+    assert rx_func.body.blocks[0].bindings[1].value.args[0].name_hint == "te_func"
+    assert rx_func.body.blocks[0].bindings[2].value.args[0].name_hint == "te_func1"
+
+
+def test_emit_te_multiple_output():
+    bb = rx.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    type_anno = rx.DynTensorType(2, "float32")
+    x = rx.Var("x", [n, m], type_anno)
+
+    def te_func(A):
+        B0, B1 = te.compute((n, m), lambda i, j: (A[i, j] + 1, A[i, j] * 2), name="B")
+        return (B0, B1)
+
+    with bb.function("rx_func", [x]):
+        y = bb.emit_te(te_func, x)
+        z = rx.TupleGetItem(y, 0)
+        bb.emit_func_output([y, z])
+
+    rx_func = bb.get()["rx_func"]
+
+    # check call tir output shape is a Tuple of ShapeExpr
+    assert rx_func.params[0] == x
+    assert rx_func.attrs["global_symbol"] == "rx_func"
+    call_node = rx_func.body.blocks[0].bindings[0].value
+    assert call_node.op == relay.op.get("relax.call_tir")
+    assert call_node.args[0].name_hint == "te_func"
+    assert isinstance(call_node.args[2], rx.Tuple)
+    assert len(call_node.args[2]) == 2
+    assert isinstance(call_node.args[2][0], rx.ShapeExpr)
+    assert isinstance(call_node.args[2][1], rx.ShapeExpr)
+
+
+def test_emit_te_extern():
+    bb = rx.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+    type_anno = rx.DynTensorType(2, "float32")
+    x = rx.Var("x", [n, m], type_anno)
+    y = rx.Var("y", [m, n], type_anno)
+
+    with bb.function("rx_cblas_matmul", [x, y]):
+        out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
+        bb.emit_func_output(out)
+
+    mod = bb.get()
+    rx_func = mod["rx_cblas_matmul"]
+
+    # check Relax function calls TIR function with call_tir call
+    assert rx_func.params[0] == x
+    assert rx_func.params[1] == y
+    assert len(rx_func.body.blocks) == 1
+    call_node = rx_func.body.blocks[0].bindings[0].value
+    assert isinstance(call_node, rx.Call)
+    assert call_node.op == relay.op.get("relax.call_tir")
+    assert len(call_node.args) == 3
+    assert call_node.args[0].name_hint == "matmul"
+    assert call_node.args[1][0] == x
+    assert call_node.args[1][1] == y
+    assert call_node.args[2][0] == n
+    assert call_node.args[2][1] == n
+
+
+def test_emit_tuple_get_item():
+    bb = rx.BlockBuilder()
+    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
+
+    with bb.function("rx_func"):
+        data = nn.Placeholder((n, m, 224, 224), name="x")
+        gamma = nn.Parameter((m,))
+        beta = nn.Parameter((m,))
+        moving_mean = nn.Parameter((m,))
+        moving_var = nn.Parameter((m,))
+        y = bb.emit_te(topi.nn.batch_norm, data, gamma, beta, moving_mean, moving_var)
+
+        z = bb.emit(rx.TupleGetItem(y, 0))
+        assert z.shape[0] == n
+        assert z.shape[1] == m
+        assert z.shape[2] == 224
+        assert z.shape[3] == 224
+        assert
z.checked_type.ndim == 4 + assert z.checked_type.dtype == "float32" + + w = bb.emit(rx.TupleGetItem(y, 1)) + assert w.shape[0] == m + assert w.checked_type.dtype == "float32" + + o = bb.emit(rx.TupleGetItem(y, 2)) + assert o.shape[0] == m + assert o.checked_type.dtype == "float32" + bb.emit_func_output([y, w], params=[data, gamma, beta, moving_mean, moving_var]) + + func = bb.get()["rx_func"] + assert len(func.body.blocks[0].bindings) == 4 + + +def test_nested_function_fail(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, x)) + with bb.function("func1", [x, y]): + gv1 = bb.emit(rx.op.add(x, x)) + bb.emit_func_output(gv0) + + +def test_emit_func_output_twice_fail(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + bb.emit_func_output(gv0) + + +def test_func_params_twice_fail(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0, [x]) + + +def test_no_func_params_fail(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + type_anno0 = rx.DynTensorType(ndim=2, dtype="float16") + type_anno1 = rx.DynTensorType(ndim=1, dtype="float16") + x = rx.Var("x", [m, n], type_anno0) + y = rx.Var("y", [n], type_anno1) + bb = rx.BlockBuilder() + + with pytest.raises(RuntimeError): + with bb.function("func"): + gv0 = bb.emit(rx.Call(ExternFunc("test.blockbuilder.nop"), None)) + bb.emit_func_output(gv0) + + +def test_block_builder_scope_recovery(): + bb = rx.BlockBuilder() + + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", [n, m], type_anno) + y = rx.Var("y", [m, n], type_anno) + + with pytest.raises(RuntimeError): + # this line fails + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + + # current should be recovered + assert rx.BlockBuilder.current() is None + + # second attempt to do it correctly. + with bb.function("func", [x, y]): + gv0 = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv0) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py new file mode 100644 index 0000000000..5138dc03d7 --- /dev/null +++ b/tests/python/relax/test_dataflow_pattern.py @@ -0,0 +1,840 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import pytest + +from tvm import relay +from tvm.relax.dpl import * +from tvm.relax.analysis import get_var2val +from tvm import relax as rx, tir +from tvm.script import relax as R, tir as T + + +@tvm.script.ir_module +class Module: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.var("int32") + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) + with T.init(): + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @T.prim_func + def tir_relu(x: T.handle, y: T.handle): + T.func_attr({"global_symbol": "tir_relu"}) + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + for (i, j) in T.grid(32, 32): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.max(A[vi, vj], 0.0) + + @R.function + def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype="float32") + R.output(lv1) + return lv1 + + +main_fn = Module["main"] +bindings = main_fn.body.blocks[0].bindings + +## Node-wise Matching +def test_expr_pattern(): + ep = is_expr(rx.Var("x")) + assert isinstance(ep, ExprPattern) + assert isinstance(ep.expr, rx.Var) + + +def test_var_pattern(): + v = is_var("x") + assert isinstance(v, VarPattern) + assert v.name == "x" + assert v.match(rx.Var("x")) + assert is_var().match(rx.Var("x")) + assert is_var().match(rx.DataflowVar("x")) # DataflowVar is also a Var + assert not v.match(rx.GlobalVar("x")) + + +def test_dataflow_var_pattern(): + v = is_dfv("x") + assert isinstance(v, DataflowVarPattern) + assert v.name == "x" + assert v.match(rx.DataflowVar("x")) + assert not v.match(rx.GlobalVar("x")) + assert is_dfv().match(bindings[0].var) + + +def test_global_var_pattern(): + assert is_gv("x").match(rx.GlobalVar("x")) + assert is_gv().match(rx.GlobalVar("x")) + assert not is_gv("x").match(rx.GlobalVar("y")) + assert not is_gv("x").match(rx.Var("x")) + + +def test_constant_pattern(): + c = is_const() + assert isinstance(c, ConstantPattern) + assert c.match(rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]])) + + +def test_wildcard_pattern(): + wc = wildcard() + assert isinstance(wc, WildcardPattern) + assert wc.match(rx.Var("x")) + + +def test_call_pattern(): + wc1 = wildcard() + wc2 = wildcard() + c = is_op("relax.add")(wc1, wc2) + assert isinstance(c, CallPattern) + assert isinstance(c.args[0], WildcardPattern) + assert isinstance(c.args[1], WildcardPattern) + assert c.match(rx.op.add(rx.Var("x"), rx.Var("y"))) + + +def test_function_pattern(): + wc1 = wildcard() + wc2 = wildcard() + f = FunctionPattern([wc1, wc2], is_op("relax.add")(wc1, wc2)) + assert isinstance(f, 
FunctionPattern) + assert isinstance(f.params[0], WildcardPattern) + assert isinstance(f.params[1], WildcardPattern) + assert isinstance(f.body, CallPattern) + assert isinstance(f.body.args[0], WildcardPattern) + assert isinstance(f.body.args[1], WildcardPattern) + ttype = rx.DynTensorType(-1, "float32") + x = rx.Var("x", type_annotation=ttype) + y = rx.Var("y", type_annotation=ttype) + assert f.match( + rx.Function([x, y], rx.op.add(x, y), ret_type=ttype, ret_shape=rx.RuntimeDepShape()) + ) + assert not f.match( + rx.Function([x, y], rx.op.multiply(x, y), ret_type=ttype, ret_shape=rx.RuntimeDepShape()) + ) + + +def test_tuple_pattern(): + wc1 = wildcard() + wc2 = is_dfv() + t = is_tuple([wc1, wc2]) + assert isinstance(t, TuplePattern) + assert isinstance(t.fields[0], WildcardPattern) + assert isinstance(t.fields[1], DataflowVarPattern) + assert t.match(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")])) + assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")])) + assert not t.match(rx.Tuple([])) + assert t[0].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)) + assert t[1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + # Negative index is also allowed + assert t[-1].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + # None means any index. + assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)) + assert t[None].match(rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)) + with pytest.raises(IndexError): + t[2] # index cannot be greater than or equal to the tuple size. + + +def test_unordered_tuple_pattern(): + t = is_tuple([is_const(), is_dfv()], unordered=True) + assert isinstance(t, UnorderedTuplePattern) + assert isinstance(t.fields[0], ConstantPattern) + assert isinstance(t.fields[1], DataflowVarPattern) + assert t.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")])) + assert t.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])])) + assert not t.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")])) + assert not t.match(rx.Tuple([])) + + +def test_tuple_get_item_pattern(): + assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match( + rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0) + ) + assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match( + rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0) + ) + + +def test_or_pattern(): + dfv_or_gv = is_dfv("x") | is_gv("x") + assert isinstance(dfv_or_gv, OrPattern) + assert dfv_or_gv.match(rx.DataflowVar("x")) + assert dfv_or_gv.match(rx.GlobalVar("x")) + assert not dfv_or_gv.match(rx.Var("x")) + assert not dfv_or_gv.match(rx.DataflowVar("y")) + assert not dfv_or_gv.match(rx.GlobalVar("y")) + + +def test_and_pattern(): + # float[2, 3, 3] + f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32") + assert isinstance(f32_233, AndPattern) + assert f32_233.match(rx.Var("x", (2, 3, 3), rx.DynTensorType(3, "float32"))) + assert not f32_233.match(rx.Var("x", (3, 3, 3), rx.DynTensorType(3, "float32"))) + assert not f32_233.match(rx.Var("x", rx.RuntimeDepShape(), rx.DynTensorType(3, "float32"))) + + +def test_not_pattern(): + no_shape233 = ~wildcard().has_shape((2, 3, 3)) + assert isinstance(no_shape233, NotPattern) + assert no_shape233.match(rx.Var("x", (3, 3, 3), rx.DynTensorType(3, "float32"))) + assert not no_shape233.match(rx.Var("x", (2, 3, 3), rx.DynTensorType(3, "float32"))) + + +def test_type_pattern(): + 
assert wildcard().has_type(rx.DynTensorType(2, "float32")).match(bindings[0].var) + + +def test_dtype_pattern(): + dtype = "float16" + pattern = has_dtype(dtype) + assert isinstance(pattern, DataTypePattern) + assert pattern.dtype == dtype + assert has_dtype("float32").match(bindings[0].var) + + +def test_shape_pattern(): + shape = [32, 32] + pattern = wildcard().has_shape(shape) + assert isinstance(pattern, ShapePattern) + assert tvm.ir.structural_equal(pattern.shape, shape) + assert pattern.match(bindings[0].var) + assert wildcard().has_shape([32, 32]).match(bindings[0].var) + n, m = tir.Var("n", dtype="int32"), tir.Var("m", dtype="int32") + symbolic_shape = rx.ShapeExpr([n, m, n + m]) + symsh_var = rx.Var("x", symbolic_shape, rx.DynTensorType(3, "float32")) + assert wildcard().has_shape([n, m, n + m]).match(symsh_var) + assert wildcard().has_shape([n, m, m + n]).match(symsh_var) # + is commutative. + assert not wildcard().has_shape([1, 2, 3]).match(symsh_var) + assert not wildcard().has_shape([m, n, n + m]).match(symsh_var) + + +def test_prim_arr_pattern(): + """ + The difference between is_shape and has_shape is that: + 1) is_shape directly matches a shape (e.g., as an argument); + 2) has_shape matches a tensor and puts assumptions on the tensor's shape. + """ + pattern = is_shape([32, 32]) + assert pattern[0] == 32 + assert pattern[1] == 32 + assert isinstance(pattern, PrimArrPattern) + assert pattern.match(bindings[0].var.shape) + n, m = tir.Var("n", dtype="int32"), tir.Var("m", dtype="int32") + symbolic_shape = rx.ShapeExpr([n, m, n + m]) + assert is_shape([n, m, n + m]).match(symbolic_shape) + assert not is_shape([n, m, n * m]).match(symbolic_shape) + + +def test_rt_dep_shape_pattern(): + # runtime-dep-shape var + rts_var = rx.Var("rts_var", rx.RuntimeDepShape(), rx.DynTensorType(4, "float32")) + # static-shape var + ss_var = rx.Var("ss_var", rx.ShapeExpr([32, 32]), rx.DynTensorType(4, "float32")) + assert wildcard().has_rt_dep_shape().match(rts_var) + assert not wildcard().has_rt_dep_shape().match(ss_var) + + +def test_extern_fn_pattern(): + pattern = ExternFuncPattern("test.blockbuilder.nop") + assert pattern.match(rx.ExternFunc("test.blockbuilder.nop")) + + +def test_op_attr(): + ttype = rx.DynTensorType(-1, "float32") + x = rx.Var("x", type_annotation=ttype) + y = rx.Var("y", type_annotation=ttype) + conv2d = relay.nn.conv2d(x, y, kernel_size=(3, 3)) + xp = is_var("x") + yp = is_var("y") + assert is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [3, 3]}).match(conv2d) + assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [4, 3]}).match(conv2d) + assert not is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size_": [3, 3]}).match(conv2d) + + +def test_match_call_attr(): + ttype = rx.DynTensorType(-1, "float32") + x = rx.Var("x", type_annotation=ttype) + y = rx.Var("y", type_annotation=ttype) + fn = rx.Function([x, y], rx.op.add(x, y), ret_type=ttype, ret_shape=rx.RuntimeDepShape()) + annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + xp = is_var("x") + yp = is_var("y") + root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp)) + assert root_pattern.has_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"}).match( + annotated_fn + ) + + assert root_pattern.has_attr({"Codegen": "test-codegen"}).match(annotated_fn) + assert not root_pattern.has_attr({"ping": "pong"}).match(annotated_fn) + assert root_pattern.has_attr({}).match(annotated_fn) + + +def test_is_call_tir(): + lv1_val = bindings[1].value + var2val =
get_var2val(Module["main"]) + assert is_call_tir("tir_relu").match(lv1_val) + assert is_call_tir("tir_relu", is_call_tir("tir_matmul")).match(lv1_val, var2val=var2val) + assert not is_call_tir("tir_relu", is_call_tir("tir_relu")).match(lv1_val, var2val=var2val) + + +@R.function +def simple_call_packed(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + gv0 = R.call_packed("test.vm.mul", x, w, type_args=(Tensor(ndim=2, dtype="float32"))) + return gv0 + + +def test_varg_default_wildcard(): + expr = simple_call_packed.body.blocks[0].bindings[0].value + yes_pattern_explicit = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard()) + yes_pattern_implicit = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True) + no_pattern = ExternFuncPattern("test.vm.mul")(wildcard()) + + assert yes_pattern_explicit.match(expr) + assert yes_pattern_implicit.match(expr) + assert not no_pattern.match(expr) + + +def test_simple_call_packed(): + expr = simple_call_packed.body.blocks[0].bindings[0].value + assert is_call_packed("test.vm.mul").match(expr) + assert is_call_packed("test.vm.mul", [is_var("x"), is_var("w")]).match(expr) + + +## Graph-wise Matching +def test_simple_used_by(): + with PatternContext() as ctx: + n0 = is_var("x") # x is a free var (fn arg) + n1 = wildcard() + n0 ^ n1 + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == main_fn.params[0] + assert matched[n1] == dfb.bindings[0].var + + +def test_simple_call_tir_edge(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n0.used_by(n1) + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == dfb.bindings[0].var + assert matched[n1] == dfb.bindings[1].var + + +def test_simple_oub(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n0 >> n1 + dfb = main_fn.body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + assert matched[n0] == dfb.bindings[0].var + assert matched[n1] == dfb.bindings[1].var + + +def test_counter_syntax_match(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_impossible") + n0 >> n1 + dfb = main_fn.body.blocks[0] + assert not ctx.match_dfb(dfb) + + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_impossible") + n0 ^ n1 + dfb = main_fn.body.blocks[0] + assert not ctx.match_dfb(dfb) + + +@tvm.script.ir_module +class Diamond: + @R.function + def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + # matmul + # / \ + # relu sigmoid + # \ / + # add + lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + lv1 = R.call_tir(tir_relu, (lv0,), (32, 32), dtype="float32") + lv2 = R.call_tir(tir_sigmoid, (lv0), (32, 32), dtype="float32") + lv3 = R.call_tir(tir_add, (lv1, lv2), (32, 32), dtype="float32") + R.output(lv3) + return lv3 + + +def test_diamond(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n2 = is_call_tir("tir_sigmoid") + n3 = is_call_tir("tir_add") + + n0 ^ n1 + n0 ^ n2 + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + # simplify it with fork_to + with PatternContext() as ctx: + n1 = is_call_tir("tir_relu") + n2 = is_call_tir("tir_sigmoid") + n3 = is_call_tir("tir_add") + + is_call_tir("tir_matmul").fork_to(n1, n2) + n1 >> n3 + n2 >> n3 + + dfb = 
Diamond["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_diamond_counter_oub(): + with PatternContext() as ctx: + n0 = is_call_tir("tir_matmul") + n1 = is_call_tir("tir_relu") + n2 = is_call_tir("tir_sigmoid") + n3 = is_call_tir("tir_add") + + n0 >> n1 + n0 >> n2 + n1 >> n3 + n2 >> n3 + + dfb = Diamond["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + + +@tvm.script.ir_module +class SmallDiamond: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + # relu + # / \ + # \ / + # add + lv0 = R.call_tir(my_relu, (x,), (32, 32), dtype="float32") + lv1 = R.call_tir(my_add, (lv0, lv0), (32, 32), dtype="float32") + R.output(lv1) + return lv1 + + +@tvm.script.ir_module +class SmallParallel: + @R.function + def main(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + # relu relu + # \ / + # add + lv0 = R.call_tir(my_relu, (x,), (32, 32), dtype="float32") + lv1 = R.call_tir(my_relu, (x,), (32, 32), dtype="float32") + lv2 = R.call_tir(my_add, (lv0, lv1), (32, 32), dtype="float32") + R.output(lv2) + return lv2 + + +def test_distiguish_diamond_and_parallel(): + # relay pattern lang cannot distinguish the two cases above. + diamond = SmallDiamond["main"].body.blocks[0] + parallel = SmallParallel["main"].body.blocks[0] + + with PatternContext() as ctx: + # describe a diamond pattern + fork = is_call_tir("my_relu") + join = is_call_tir("my_add") + fork.only_used_by(join, index=0) + fork.only_used_by(join, index=1) + + assert ctx.match_dfb(diamond) + assert not ctx.match_dfb(parallel) + + with PatternContext() as ctx: + # describe a parallel pattern + join = is_call_tir("my_add") + # Due to one-one mathcing: + # is_call_tir("my_relu") creates the 1st relu + is_call_tir("my_relu") >> join + # is_call_tir("my_relu") creates the another different relu (obj address is different) + is_call_tir("my_relu") >> join + + assert ctx.match_dfb(parallel) + assert not ctx.match_dfb(diamond) + + +@tvm.script.ir_module +class CBRx2: + @R.function + def main( + x: Tensor((32, 32), "float32"), + w0: Tensor((1, 1), "float32"), + bias0: Tensor((32, 32), "float32"), + w1: Tensor((1, 1), "float32"), + bias1: Tensor((32, 32), "float32"), + ) -> Tensor: + # TensorRT's CBR Optimization Pattern + # input + # / \ + # cbr0 cbr1 + # \ / + # concat + with R.dataflow(): + lv0 = R.call_tir(conv1x1, (x, w0), (32, 32), dtype="float32") + lv1 = R.call_tir(bias_add, (lv0, bias0), (32, 32), dtype="float32") + lv2 = R.call_tir(my_relu, (lv1), (32, 32), dtype="float32") + lv3 = R.call_tir(conv1x1, (x, w1), (32, 32), dtype="float32") + lv4 = R.call_tir(bias_add, (lv3, bias1), (32, 32), dtype="float32") + lv5 = R.call_tir(my_relu, (lv4), (32, 32), dtype="float32") + lv6 = R.call_tir(concat, (lv2, lv5), (32, 64), dtype="float32") + R.output(lv6) + return lv6 + + +def test_single_cbr(): + with PatternContext() as ctx: + is_call_tir("conv1x1") >> is_call_tir("bias_add") >> is_call_tir("my_relu") + dfb = CBRx2["main"].body.blocks[0] + matched = ctx.match_dfb(dfb) + assert matched + + with PatternContext() as ctx: + chain = is_call_tir("conv1x1") >> is_call_tir("bias_add") >> is_call_tir("my_relu") + dfb = CBRx2["main"].body.blocks[0] + # we want to specifically match the first CBR (lv0) + matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var) + assert matched + assert matched[chain[0]] == dfb.bindings[0].var + # we want to specifically match the second CBR (lv3) + matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var) + assert matched + assert matched[chain[0]] == 
dfb.bindings[3].var + + +def test_counter_single_crb(): + with PatternContext() as ctx: + is_call_tir("conv1x1") >> is_call_tir("my_relu") >> is_call_tir("bias_add") + dfb = CBRx2["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + # Quickly fail unpromising matches by assuming `start_hint` must be matched by a pattern. + # This is usually faster than the full match: + # Full match: let one pattern match -> all Vars: complexity ~ #Var + # must_include_hint: let `start_hint` match -> all patterns: complexity ~ #patterns + # Usually #patterns is much smaller than #Var, so this is faster. + assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True) + + +def test_nested_context(): + dfb = CBRx2["main"].body.blocks[0] + with PatternContext() as ctx0: + is_call_tir("conv1x1") >> is_call_tir("bias_add") >> is_call_tir("my_relu") + with PatternContext() as ctx1: + is_call_tir("conv1x1") >> is_call_tir("my_relu") # pattern to miss + with PatternContext() as ctx2: + is_call_tir("bias_add") >> is_call_tir("my_relu") + assert ctx2.match_dfb(dfb) + assert PatternContext.current() == ctx2 + assert not ctx1.match_dfb(dfb) + assert PatternContext.current() == ctx1 + assert ctx0.match_dfb(dfb) + assert PatternContext.current() == ctx0 + + +def test_two_cbr(): + with PatternContext() as ctx: + cbr0 = is_call_tir("conv1x1") >> is_call_tir("bias_add") >> is_call_tir("my_relu") + cbr1 = cbr0.dup() + + assert cbr0.patterns[0] != cbr1.patterns[0] + assert cbr0.patterns[1] != cbr1.patterns[1] + assert cbr0.patterns[2] != cbr1.patterns[2] + + is_var("x").fork_to(cbr0, cbr1) + dfb = CBRx2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + # Deny the pattern + cbr0 = is_call_tir("conv1x1") >> is_call_tir("bias_add") >> is_call_tir("my_relu") + cbr1 = cbr0.dup() + + # input has no fork at y. + is_var("y").fork_to(cbr0, cbr1) + dfb = CBRx2["main"].body.blocks[0] + assert not ctx.match_dfb(dfb) + + +def test_two_matmul(): + # Same as Figure 2(a) in TASO paper. + @tvm.script.ir_module + class MatMul2: + @R.function + def main( + a: Tensor((32, 16), "float32"), + b: Tensor((16, 48), "float32"), + c: Tensor((48, 32), "float32"), + ) -> Tensor: + with R.dataflow(): + lv0 = R.call_tir(matmul, (a, b), (32, 48), dtype="float32") + lv1 = R.call_tir(matmul, (lv0, c), (32, 32), dtype="float32") + R.output(lv1) + return lv1 + + with PatternContext() as ctx: + is_call_tir("matmul") >> is_call_tir("matmul") + dfb = MatMul2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + is_call_tir("matmul").has_shape([32, 48]) >> is_call_tir("matmul").has_shape([32, 32]) + dfb = MatMul2["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + is_call_tir("matmul") >> is_call_tir("matmul") >> is_call_tir("matmul") + dfb = MatMul2["main"].body.blocks[0] + # Three chained matmuls cannot match + assert not ctx.match_dfb(dfb) + + +def test_concat_mm_split(): + # Same as Figure 2(b) in TASO paper.
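+    # A rough sketch of the dataflow graph under test (mirroring the
+    # TVMScript module defined just below):
+    #
+    #   concat(b, c) -> matmul(a, _) -> split -> (lv3, lv4) -> add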
+ @tvm.script.ir_module + class CMS: + @R.function + def main( + a: Tensor((32, 32), "float32"), + b: Tensor((16, 32), "float32"), + c: Tensor((16, 32), "float32"), + ) -> Tensor: + with R.dataflow(): + lv0 = R.call_tir(my_concat, (b, c), (32, 32), dtype="float32") + lv1 = R.call_tir(my_matmul, (a, lv0), (32, 32), dtype="float32") + lv2 = R.call_tir( + my_split, + (lv1,), + ((16, 32), (16, 32)), + dtype=("float32", "float32"), + ) + lv3 = R.TupleGetItem(lv2, 0) + lv4 = R.TupleGetItem(lv2, 1) + lv5 = R.add(lv3, lv4) + R.output(lv5) + return lv5 + + with PatternContext() as ctx: + is_call_tir("my_concat") >> is_call_tir("my_matmul") >> is_call_tir("my_split") + dfb = CMS["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + with PatternContext() as ctx: + split = is_call_tir("my_split") + lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32]) + lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32]) + split.fork_to(lv3, lv4) + add = is_op("relax.add")(lv3, lv4) + # TODO(@ganler): simplify this through implicit graph pattern. + lv3 >> add + lv4 >> add + + dfb = CMS["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_self_attention(): + # The example comes from: + # https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/ + @tvm.script.ir_module + class SelfAttention: + @R.function + def main( + x: Tensor((b, s, n, h), "float32"), + wq: Tensor((h, h), "float32"), + wk: Tensor((h, h), "float32"), + wv: Tensor((h, h), "float32"), + ) -> Tensor: + with R.dataflow(): + fcq = R.call_tir(my_fc, (x, wq), (b, s, n, h), dtype="float32") + tpq = R.call_tir(my_transpose, (fcq,), (b, s, h, n), dtype="float32") + + fck = R.call_tir(my_fc, (x, wk), (b, s, n, h), dtype="float32") + tpk = R.call_tir(my_transpose, (fck,), (b, s, h, n), dtype="float32") + + mul = R.multiply(tpq, tpk) + scale = R.multiply(mul, R.const(1.1, "float32")) + softmax = R.call_tir(softmax, (scale,), (b, s, n, h), dtype="float32") + + fcv = R.call_tir(my_fc, (x, wv), (b, s, n, h), dtype="float32") + tpv = R.call_tir(my_transpose, (fcv,), (b, s, h, n), dtype="float32") + + out = R.multiply(softmax, tpv) + R.output(out) + + return out + + with PatternContext() as ctx: + fc_trans_q = is_call_tir("my_fc") >> is_call_tir("my_transpose") + fc_trans_k = fc_trans_q.dup() + fc_trans_v = fc_trans_q.dup() + + is_var("x").fork_to(fc_trans_q, fc_trans_k, fc_trans_v) + dfb = SelfAttention["main"].body.blocks[0] + assert ctx.match_dfb(dfb) + + +def test_nested_diamond(): + @tvm.script.ir_module + class DiamondInDiamond: + @R.function + def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + # matmul0 matmul1 + # / \ / \ + # sigmoid2 add4 sigmoid3 + # \ / \ / + # add5 add6 + # \ / + # add7 + lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + lv1 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + lv2 = R.call_tir(tir_sigmoid, (lv0), (32, 32), dtype="float32") + lv3 = R.call_tir(tir_sigmoid, (lv1), (32, 32), dtype="float32") + lv4 = R.call_tir(tir_add, (lv0, lv1), (32, 32), dtype="float32") + lv5 = R.call_tir(tir_add, (lv2, lv4), (32, 32), dtype="float32") + lv6 = R.call_tir(tir_add, (lv3, lv4), (32, 32), dtype="float32") + lv7 = R.call_tir(tir_add, (lv5, lv6), (32, 32), dtype="float32") + R.output(lv7) + return lv7 + + # match matmul0 diamond + with PatternContext() as ctx: + sigmoid2 = is_call_tir("tir_sigmoid") + add4 = is_call_tir("tir_add") + is_call_tir("tir_matmul").fork_to(sigmoid2, add4) + add5 = is_call_tir("tir_add") + sigmoid2 >> add5 + add4 ^ add5 +
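+        # Note: `a ^ b` is sugar for a.used_by(b), while `a >> b` means
+        # a.only_used_by(b); add4 also feeds add6, so only the weaker
+        # used-by edge can hold between add4 and add5.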
assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # counter case: mis-match matmul0 diamond + with PatternContext() as ctx: + sigmoid2 = is_call_tir("tir_sigmoid") + add4 = is_call_tir("tir_add") + is_call_tir("tir_matmul").fork_to(sigmoid2, add4) + add5 = is_call_tir("tir_add") + sigmoid2 >> add5 + add4 >> add5 # not only-used-by relation + assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # match matmul1 diamond + with PatternContext() as ctx: + sigmoid3 = is_call_tir("tir_sigmoid") + add4 = is_call_tir("tir_add") + is_call_tir("tir_matmul").fork_to(sigmoid3, add4) + add6 = is_call_tir("tir_add") + sigmoid3 >> add6 + add4 ^ add6 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + # match add-4-5-6-7 + with PatternContext() as ctx: + add5, add6, add7 = is_call_tir("tir_add"), is_call_tir("tir_add"), is_call_tir("tir_add") + is_call_tir("tir_add").fork_to(add5, add6) # add4 + add5 >> add7 + add6 >> add7 + assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) + + +def test_incremental_solving(): + @R.function + def simple_chain(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + # relu -> sigmoid -> neg + lv0 = R.call_tir(tir_relu, (x), (32, 32), dtype="float32") + lv1 = R.call_tir(tir_sigmoid, (lv0), (32, 32), dtype="float32") + lv2 = R.call_tir(tir_neg, (lv1), (32, 32), dtype="float32") + R.output(lv2) + return lv2 + + relu = is_call_tir("tir_relu") + sigmoid = is_call_tir("tir_sigmoid") + neg = is_call_tir("tir_neg") + + with PatternContext() as ctx0: + relu >> sigmoid + with PatternContext(incremental=True) as ctx1: + # Because we are doing incremental solving, + # relu >> sigmoid is still a constraint in this context. + # That said, the total constraint is: + # relu >> sigmoid >> neg + sigmoid >> neg + assert ctx1.match_dfb(simple_chain.body.blocks[0]) + + # match relu -> sigmoid + assert ctx0.match_dfb(simple_chain.body.blocks[0]) + + +def test_incremental_solving_counter(): + @R.function + def simple_chain(x: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + # sigmoid -> neg + lv0 = R.call_tir(tir_sigmoid, (x), (32, 32), dtype="float32") + lv1 = R.call_tir(tir_neg, (lv0), (32, 32), dtype="float32") + R.output(lv1) + return lv1 + + relu = is_call_tir("tir_relu") + sigmoid = is_call_tir("tir_sigmoid") + neg = is_call_tir("tir_neg") + + with PatternContext() as ctx0: + relu >> sigmoid # cannot match + + with PatternContext(incremental=False) as ctx1: + # total constraint: sigmoid >> neg + sigmoid >> neg + assert ctx1.match_dfb(simple_chain.body.blocks[0]) + + with PatternContext(incremental=True) as ctx1: + # total constraint: relu >> sigmoid >> neg + sigmoid >> neg + assert not ctx1.match_dfb(simple_chain.body.blocks[0]) diff --git a/tests/python/relax/test_expr.py new file mode 100644 index 0000000000..c6eac84df2 --- /dev/null +++ b/tests/python/relax/test_expr.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir +from tvm import relax as rx +import numpy as np + + +def test_var() -> None: + v0 = rx.Var("v0") + assert v0.name_hint == "v0" + assert v0.shape_ is None + assert v0._checked_type_ is None + shape_anno = [54, 96] + type_anno = rx.DynTensorType(2, "float32") + v1 = rx.Var("v1", shape_anno, type_anno) + assert v1.name_hint == "v1" + for s0, s1 in zip(v1.shape_, shape_anno): + assert s0 == s1 + assert v1._checked_type_ == type_anno + + +def test_dataflow_var() -> None: + v0 = rx.DataflowVar("v0") + assert v0.name_hint == "v0" + assert v0.shape_ is None + assert v0._checked_type_ is None + shape_anno = [54, 96] + type_anno = rx.DynTensorType(2, "float16") + v1 = rx.DataflowVar("v1", shape_anno, type_anno) + assert v1.name_hint == "v1" + for s0, s1 in zip(v1.shape_, shape_anno): + assert s0 == s1 + assert v1._checked_type_ == type_anno + assert isinstance(v1, rx.DataflowVar) + + +def test_match_shape() -> None: + # match_shape([16, 8], [m, n]) + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + var = rx.Var("v0", type_annotation=rx.ShapeType()) + b0 = rx.MatchShape(shape, [m, n], var) + assert b0.value == shape + assert b0.pattern[0] == m + assert b0.pattern[1] == n + assert b0.var is not None + assert b0.var.checked_type == rx.ShapeType() + + # var1: Tensor((m, n), "float32") = + # match_shape(var0: Tensor(_, "float32"), [m, n]) + type_anno0 = rx.DynTensorType(-1, "float32") + value = rx.Var("value", type_annotation=type_anno0) + + shape_anno = [m, n] + type_anno = rx.DynTensorType(2, "float32") + var = rx.Var("v1", shape_anno, type_anno) + b1 = rx.MatchShape(value, [m, n], var) + assert b1.value == value + assert b1.pattern[0] == m + assert b1.pattern[1] == n + assert b1.var is not None + for s0, s1 in zip(b1.var.shape, [m, n]): + assert s0 == s1 + assert b1.var.checked_type == rx.DynTensorType(2, "float32") + + +def test_var_binding() -> None: + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b0 = rx.VarBinding(v0, val) + assert b0.var.name_hint == "v0" + assert b0.value == val + + +def test_binding_block() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + assert isinstance(block0, rx.DataflowBlock) + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + 
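+    # A SeqExpr evaluates its binding blocks in order and then yields `body`;
+    # here the body is simply the bound var x.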
assert seqe.blocks[0] == blocks[0] + assert seqe.body == x + + +def test_shape_expr() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + s = rx.ShapeExpr([m, n]) + assert s.values[0] == m + assert s.values[1] == n + + +def test_func(): + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("foo", type_annotation=type_anno) + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + ret_type = rx.DynTensorType(-1, "float32") + ret_shape = rx.RuntimeDepShape() + func = rx.Function([x], seqe, ret_type, ret_shape) + func = func.with_attr("global_symbol", "func") + assert func.params[0] == x + assert func.body == seqe + assert func.ret_type == ret_type + assert func.ret_shape == ret_shape + assert func.attrs["global_symbol"] == "func" + + +def test_shape_of(): + v0 = rx.Var("v0") + s0 = v0.shape + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.shape_of" + + shape_anno = [96, 54] + v1 = rx.Var("v1", shape_anno) + s1 = v1.shape + for x, y in zip(shape_anno, s1): + assert x == y + + +def test_shape_expr_checked_type(): + shape_expr = rx.ShapeExpr([10, 20]) + assert shape_expr.values[0] == 10 + assert shape_expr.values[1] == 20 + assert shape_expr.checked_type == rx.ShapeType() + assert shape_expr.shape_ is None + + x = rx.Var("v0", (10, 20), rx.DynTensorType(2, "float32")) + assert x.shape_.values[0] == 10 + assert x.shape_.values[1] == 20 + assert x.shape_.checked_type == rx.ShapeType() + assert x.shape_.shape_ is None + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_expr_functor.py new file mode 100644 index 0000000000..debd82b3ae --- /dev/null +++ b/tests/python/relax/test_expr_functor.py @@ -0,0 +1,777 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
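+# This file exercises the Python-side ExprFunctor facilities: visitors and
+# mutators are defined by subclassing PyExprVisitor/PyExprMutator and wrapping
+# them with the @relax.expr_functor.visitor / @relax.expr_functor.mutator
+# decorators. As a minimal sketch (a hypothetical visitor, using only names
+# imported below), a Call counter could look like:
+#
+#     @relax.expr_functor.visitor
+#     class CallCounter(PyExprVisitor):
+#         def __init__(self) -> None:
+#             super().__init__()
+#             self.count = 0
+#
+#         def visit_call_(self, op: Call) -> None:
+#             self.count += 1
+#             super().visit_call_(op)  # keep the default traversal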
+import pytest + +import tvm +from tvm import relax, tir +from tvm.relax import PyExprVisitor, PyExprMutator +from tvm.ir.base import assert_structural_equal +from tvm.ir import Op +from tvm.relax.ty import DynTensorType +from tvm.relax.expr import Type, Span, Expr +from tvm.relax.expr import Function, ExternFunc +from tvm.relax.expr import Constant, Var, DataflowVar +from tvm.relax.expr import ShapeExpr, RuntimeDepShape +from tvm.relax.expr import GlobalVar, SeqExpr, Tuple +from tvm.relax.expr import Call, If, TupleGetItem +from tvm.relax.expr import Binding, MatchShape, VarBinding +from tvm.relax.expr import BindingBlock, DataflowBlock +from tvm.relax.expr import _update_shape, _update_type + +m, n = tir.Var("m", "int64"), tir.Var("n", "int64") +type_anno1 = relax.DynTensorType(1, "float32") +type_anno2 = relax.DynTensorType(2, "float32") +x = relax.Var("x", [n], type_anno1) +y = relax.Var("y", [m, n], type_anno2) +bb = relax.BlockBuilder() + + +@relax.expr_functor.visitor +class BasicVisitor(PyExprVisitor): + """Default ExprVisitor""" + + +class ASTLog: + """Helper class to log AST""" + + def __init__(self) -> None: + self.log = [] + self.indent = "\t" + self.level = 0 + + def push_scope(self): + self.level += 1 + + def pop_scope(self): + self.level -= 1 + + def add(self, s: str): + self.log.append(self.indent * self.level + s) + + def __str__(self) -> str: + return "\n".join(self.log) + + +@relax.expr_functor.visitor +class ASTPrinter(PyExprVisitor): + """Print relax AST in structured format. The shape of Node is ignored.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> None: + self.log.add("Constant") + + def visit_global_var_(self, op: GlobalVar) -> None: + self.log.add("GlobalVar") + + def visit_tuple_(self, op: Tuple) -> None: + self.log.add("Tuple") + self.log.push_scope() + for field in op.fields: + self.visit_expr(field) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + self.log.add("DataflowVar") + + def visit_function_(self, op: Function) -> None: + self.log.add("Function") + self.log.push_scope() + for param in op.params: + self.visit_var_def(param) + + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_call_(self, op: Call) -> None: + self.log.add("Call") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_if_(self, op: If) -> None: + self.log.add("If") + self.log.push_scope() + self.visit_expr(op.cond) + self.visit_expr(op.true_branch) + self.visit_expr(op.false_branch) + self.log.pop_scope() + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + self.log.add("TupleGetItem") + self.log.push_scope() + self.visit_expr(op.tuple_value) + self.log.pop_scope() + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + self.log.add("ShapeExpr") + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: + self.log.add("RuntimeDepShape") + + def visit_extern_func_(self, op: ExternFunc) -> None: + self.log.add("ExternFunc") + + def visit_seq_expr_(self, op: SeqExpr) -> None: + self.log.add("SeqExpr") + self.log.push_scope() + for block in op.blocks: + self.visit_binding_block(block) + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_var_binding_(self, binding: VarBinding) -> None: + self.log.add("VarBinding") + 
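+        # The bound value is visited before the variable definition, matching
+        # the traversal order of the C++ ExprVisitor.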
self.log.push_scope() + self.visit_expr(binding.value) + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_match_shape_(self, binding: MatchShape) -> None: + self.log.add("MatchShape") + self.log.push_scope() + self.visit_expr(binding.value) + self.visit_expr(ShapeExpr(binding.pattern)) + if binding.var: + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_binding_block_(self, block: BindingBlock) -> None: + self.log.add("BindingBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + self.log.add("DataflowBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_var_def_(self, var: Var) -> None: + self.log.add("VarDef") + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + self.log.add("DataflowVarDef") + + +@relax.expr_functor.mutator +class BasicMutator(PyExprMutator): + """Default ExprMutator""" + + +@relax.expr_functor.mutator +class ASTPostPrinterMutator(PyExprMutator): + """Print relax AST in the post order format.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Constant") + return op + + def visit_global_var_(self, op: GlobalVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("GlobalVar") + return op + + def visit_tuple_(self, op: Tuple) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Tuple") + return op + + def visit_var_(self, op: Var) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Var") + return op + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataflowVar") + return op + + def visit_function_(self, op: Function) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Function") + return op + + def visit_call_(self, op: Call) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Call") + return op + + def visit_if_(self, op: If) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("If") + return op + + def visit_op_(self, op: Op) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Op") + return op + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("TupleGetItem") + return op + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ShapeExpr") + return op + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("RuntimeDepShape") + return op + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ExternFunc") + return op + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("SeqExpr") + return op + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Identical with ExprMutator::VisitBinding_(const VarBindingNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_var = self.visit_var_def(binding.var) + + def emit(b: VarBinding): + if self.builder_.current_block_is_dataflow() and not isinstance(b.var, DataflowVar): + self.builder_.emit_output_var_binding(b) + else: + self.builder_.emit_var_binding(b) + + 
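+        # Below: if neither the var nor the value changed, the original
+        # binding is re-emitted as-is; otherwise the old vid is remapped so
+        # later uses pick up the rewritten var.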
self.log.add("VarBinding") + if binding.var.same_as(new_var) and binding.value.same_as(new_value): + emit(binding) + return + + temp = self.with_shape_and_type(new_var, new_value.shape_, new_value._checked_type_) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + emit(VarBinding(new_var, new_value)) + + def visit_match_shape_(self, binding: MatchShape) -> None: + """Identical with ExprMutator::VisitBinding_(const MatchShapeNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_pattern = self.visit_expr(ShapeExpr(binding.pattern)) + + if binding.var: + new_shape = None + if new_value._checked_type_ and isinstance(new_value._checked_type_, DynTensorType): + new_shape = new_pattern + new_var = self.visit_var_def(binding.var) + temp = self.with_shape_and_type(new_var, new_shape, new_value._checked_type_) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.log.add("MatchShape") + if binding.value.same_as(new_value) and binding.pattern.same_as(new_pattern): + if not binding.var or (binding.var and binding.var.same_as(new_var)): + self.builder_.match_shape_binding(binding) + return + + self.builder_.match_shape_binding(MatchShape(new_value, new_pattern.values, new_var)) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" + self.builder_._begin_binding_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("BindingBlock") + return self.builder_._end_block() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Identical with ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) on the C++ side.""" + self.builder_._begin_dataflow_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("DataflowBlock") + return self.builder_._end_block() + + def visit_var_def_(self, var: Var) -> None: + """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ side.""" + shape_unchanged = True + new_shape = None + if var.shape_: + new_shape = self.visit_expr(var.shape_) + shape_unchanged &= var.shape_.same_as(new_shape) + + self.log.add("VarDef") + if shape_unchanged: + return var + else: + new_var = Var(var.vid, None, var._checked_type_, var.span) + _update_shape(new_var, new_shape) + + self.set_var_remap(var.vid, new_var) + return new_var + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side.""" + shape_unchanged = True + new_shape = None + if var.shape_: + new_shape = self.visit_expr(var.shape_) + shape_unchanged &= var.shape_.same_as(new_shape) + + self.log.add("DataflowVarDef") + if shape_unchanged: + return var + else: + new_var = DataflowVar(var.vid, None, var._checked_type_, var.span) + _update_shape(new_var, new_shape) + + self.set_var_remap(var.vid, new_var) + return new_var + + +def basic_check(expr, visitor_str, mutator_str): + def visit(f, expr): + if isinstance(expr, relax.Expr): + return f.visit_expr(expr) + elif isinstance(expr, relax.BindingBlock): + return f.visit_binding_block(expr) + + # check no overloading case + basic_visitor = BasicVisitor() + visit(basic_visitor, expr) + + # check the output log + log_visitor = ASTPrinter() + visit(log_visitor, expr) + assert str(log_visitor.log) == visitor_str + + # check no overloading case + 
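+    # (a default mutator must reconstruct a structurally identical AST)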
basic_mutator = BasicMutator() + if isinstance(expr, relax.Expr): + expr = bb.normalize(expr) + assert_structural_equal(visit(basic_mutator, expr), expr) + + # check the output log and return value + post_log_mutator = ASTPostPrinterMutator() + if isinstance(expr, relax.Expr): + expr = bb.normalize(expr) + assert_structural_equal(visit(post_log_mutator, expr), expr) + assert str(post_log_mutator.log) == mutator_str + + +def test_constant(): + basic_check(relax.const(1.0), "Constant", "Constant") + + +def test_var(): + basic_check(x, "Var", "Var") + + +def test_dataflow_var(): + lv = relax.DataflowVar("lv", [n], type_anno1) + basic_check(lv, "DataflowVar", "DataflowVar") + + +def test_tuple(): + t = relax.Tuple([x, y]) + basic_check(t, "\n".join(["Tuple", "\tVar", "\tVar"]), "\n".join(["Var", "Var", "Tuple"])) + + +def test_global_var(): + gv = relax.GlobalVar("gv") + basic_check(gv, "GlobalVar", "GlobalVar") + + +def test_seq_expr(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + basic_check( + seq_expr, + "\n".join( + [ + "SeqExpr", + "\tBindingBlock", + "\t\tVarBinding", + "\t\t\tConstant", + "\t\t\tVarDef", + "\tVar", + ] + ), + "\n".join( + ["Constant", "ShapeExpr", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"] + ), + ) + + +def test_shape_expr(): + x = relax.ShapeExpr([m, n]) + basic_check(x, "ShapeExpr", "ShapeExpr") + + +def test_runtime_dep_shape(): + runtime_dep_shape = relax.RuntimeDepShape() + basic_check(runtime_dep_shape, "RuntimeDepShape", "RuntimeDepShape") + + +def test_call(): + call_node = relax.op.add(x, y) + basic_check( + call_node, + "\n".join(["Call", "\tOp", "\tVar", "\tVar"]), + "\n".join(["Op", "Var", "Var", "Call"]), + ) + + +def test_if(): + if_node = relax.If(x, x, x) + basic_check( + if_node, + "\n".join(["If", "\tVar", "\tVar", "\tVar"]), + "\n".join(["Var", "Var", "Var", "If"]), + ) + + +def test_tuple_getitem(): + tuple_getitem_node = relax.TupleGetItem(relax.Tuple([x, y]), 0) + basic_check( + tuple_getitem_node, + "\n".join(["TupleGetItem", "\tTuple", "\t\tVar", "\t\tVar"]), + "\n".join(["Var", "Var", "Tuple", "TupleGetItem"]), + ) + + +def test_binding_block(): + bb._begin_binding_block() + gv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_shape(y, [m, n]) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "BindingBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tVarDef", + "\tMatchShape", + "\t\tVar", + "\t\tShapeExpr", + "\t\tVarDef", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "VarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "VarDef", + "MatchShape", + "BindingBlock", + ] + ), + ) + + +def test_dataflow_block(): + bb._begin_dataflow_block() + lv0 = bb.emit(relax.op.add(x, y)) + gv1 = bb.match_shape(y, [m, n]) + b0 = bb._end_block() + basic_check( + b0, + "\n".join( + [ + "DataflowBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tDataflowVarDef", + "\tMatchShape", + "\t\tVar", + "\t\tShapeExpr", + "\t\tDataflowVarDef", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "DataflowVarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "DataflowVarDef", + "MatchShape", + "DataflowBlock", + ] + ), + ) + + +def test_function(): + bindings = [relax.VarBinding(x, relax.const(1))] + blocks = [relax.BindingBlock(bindings)] + seq_expr = relax.SeqExpr(blocks, x) + ret_type = 
relax.DynTensorType(-1, "float32") + ret_shape = relax.RuntimeDepShape() + func = relax.Function([x], seq_expr, ret_type, ret_shape) + basic_check( + func, + "\n".join( + [ + "Function", + "\tVarDef", + "\tSeqExpr", + "\t\tBindingBlock", + "\t\t\tVarBinding", + "\t\t\t\tConstant", + "\t\t\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "ShapeExpr", + "VarDef", + "RuntimeDepShape", + "Constant", + "ShapeExpr", + "VarDef", + "VarBinding", + "BindingBlock", + "Var", + "SeqExpr", + "Function", + ] + ), + ) + + +def test_extern_func(): + func = relax.ExternFunc("f") + basic_check(func, "ExternFunc", "ExternFunc") + + +def test_inherit(): + # The internal class is not instantiated. + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_inherit_with_cls(): + # The decorator converts `InternalVisitor` to a wrapper class. + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + # `InternalVisitor._cls` refers to the original `InternalVisitor` users defined. + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "\tOp", "\tVar", "\tVar"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_wrong_inherit(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def visit_call_(self, op: Call) -> None: + pass + + with pytest.raises( + TypeError, + match="Inheritance from a decorated object `LeafVisitor` is not allowed. 
Please inherit from `LeafVisitor._cls`.", + ): + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + pass + + +def test_call_visitor_super(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + super().visit_call_(op) # call PyExprVisitor.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + super().visit_call_(op) # call InternalVisit.visit_call_ + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +def test_call_mutator_super(): + @relax.expr_functor.mutator + class InternalMutator(PyExprMutator): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + return super().visit_call_(op) # call PyExprMutator.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + return super().visit_var_(op) # call PyExprMutator.visit_var_ + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + return super().visit_op_(op) # call PyExprMutator.visit_op_ + + @relax.expr_functor.mutator + class LeafMutator(InternalMutator._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + return super().visit_call_(op) # call InternalMutator.visit_call_ + + call_node = relax.op.add(x, y) + im = InternalMutator() + im.visit_expr(call_node) + assert str(im.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lm = LeafMutator() + lm.visit_expr(call_node) + assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_function_attr.py b/tests/python/relax/test_function_attr.py new file mode 100644 index 0000000000..671f44c789 --- /dev/null +++ b/tests/python/relax/test_function_attr.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
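+# This file checks that function- and module-level attributes attached via
+# with_attr survive a save_json/load_json roundtrip, structural equality and
+# hashing, and a pipeline of relax transform passes.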
+from __future__ import annotations +import pytest +import tvm +from tvm.script import relax as R +from tvm import relax + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x) + yhash = tvm.ir.structural_hash(y) + + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +@tvm.script.ir_module +class InputModule: + @R.function + def relax_add(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")) -> Tensor: + z1 = relax.add(x, y) + z2 = relax.add(z1, z1) + return z2 + + @R.function + def main(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")) -> Tensor: + lv0 = relax_add(x, y) + return lv0 + + +def annotate(mod, func_name, attrs): + # Get func + annot_func = mod[func_name] + # Annotate a function + for key, val in attrs.items(): + annot_func = annot_func.with_attr(key, val) + mod[func_name] = annot_func + return mod + + +def test_func_attr_setter(): + mod = InputModule + assert isinstance(mod, tvm.IRModule) + + mod = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + _check_save_roundtrip(mod) + annot_func = mod["relax_add"] + + # Test annotation + assert annot_func.attrs + assert annot_func.attrs["Codegen"] == "test-codegen" + assert annot_func.attrs["global_symbol"] == "test-symbol" + + +def test_func_attr_roundtrip_and_equality(): + mod = InputModule + assert isinstance(mod, tvm.IRModule) + mod1 = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + mod2 = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + _check_save_roundtrip(mod1) + _check_save_roundtrip(mod2) + _check_equal(mod1, mod2) + + +def test_func_attr_setter_with_passes(): + mod = InputModule + assert isinstance(mod, tvm.IRModule) + # Annotate + mod = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"}) + + # Test with passes + # Annotation should stay the same unless the pass needs to modify it + + # List of passes + passes = [relax.transform.ToNonDataflow()] + passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.VMMemoryLower()) + passes.append(relax.transform.VMShapeLower()) + seq = tvm.transform.Sequential(passes) + + # Apply passes + new_mod = seq(mod) + _check_save_roundtrip(new_mod) + + # Test annotation + func = new_mod["relax_add"] + assert func.attrs + assert func.attrs["Codegen"] == "test-codegen" + assert func.attrs["global_symbol"] == "test-symbol" + + +def test_irmodule_attr_setter_with_passes(): + mod = InputModule + assert isinstance(mod, tvm.IRModule) + # Annotate + attr = relax.const(1, "float32") + mod = mod.with_attr("test-attr", attr) + mod_attr = mod.get_attrs() + + # Test with passes + # Annotation should stay the same unless the pass needs to modify it + + # List of passes + passes = [relax.transform.ToNonDataflow()] + passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.VMMemoryLower()) + passes.append(relax.transform.VMShapeLower()) + seq = tvm.transform.Sequential(passes) + + # Apply passes + new_mod = seq(mod) + _check_save_roundtrip(new_mod) + + # Check IRModule attrs is preserved after applying passes + assert new_mod.get_attrs()["test-attr"] == attr + assert new_mod.get_attrs() == mod_attr + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_op.py b/tests/python/relax/test_op.py new 
file mode 100644 index 0000000000..4ea6a3e332 --- /dev/null +++ b/tests/python/relax/test_op.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir +from tvm import relax as rx +from tvm.script import tir as T + + +@tvm.register_func("test.op.identity") +def identity_packed(a): + return tvm.nd.array(a.asnumpy()) + + +@T.prim_func +def identity_tir(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [54, 96]) + B = T.match_buffer(b, [54, 96]) + + for i, j in T.grid(54, 96): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + +def test_call_tir() -> None: + shape_anno = [54, 96] + type_anno = rx.DynTensorType(2, "float32") + v0 = rx.Var("v0", shape_anno, type_anno) + v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], [54, 96], "float32") + v1 = rx.call_tir(identity_tir, [v0], [54, 96], "float32") + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py new file mode 100644 index 0000000000..b7f196d399 --- /dev/null +++ b/tests/python/relax/test_parser.py @@ -0,0 +1,894 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations # must import to defer parsing of annotations +import pytest +import tvm +from tvm import tir, relay, relax +from tvm.ir import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + +# TODO: replace xfails with proper diagnostics checking. +# c.f. 
+# c.f. tests/python/unittest/test_tvmscript_error_report.py
+
+
+def check_shape(e, s):
+    if isinstance(e, relax.ShapeExpr):
+        pass
+    elif isinstance(e, relax.Call):
+        e = e.shape
+    elif isinstance(e, relax.Expr):
+        e = e.shape_
+
+    if s is None:
+        assert e is None
+        return
+
+    if isinstance(s, relax.RuntimeDepShape):
+        assert isinstance(e, relax.RuntimeDepShape)
+        return
+
+    assert len(e) == len(s)
+
+    for edim, sdim in zip(e, s):
+        if isinstance(sdim, str):
+            assert isinstance(edim, tir.Var)
+            assert edim.name == sdim
+        else:
+            assert isinstance(edim, tir.IntImm)
+            assert edim.value == sdim
+
+
+def check_tensor_var(v, s, d, ndim=None):
+    assert isinstance(v._checked_type_, relax.ty.DynTensorType)
+    assert v._checked_type_.dtype == d
+    if isinstance(s, (list, tuple)):
+        assert v._checked_type_.ndim == len(s)
+    if ndim is not None:
+        assert v._checked_type_.ndim == ndim
+    check_shape(v, s)
+
+
+def check_call(call, op, args):
+    assert isinstance(call, relax.Call)
+    if isinstance(op, str):
+        op = relay.op.get(op)
+    assert call.op == op
+    assert_structural_equal(call.args, args)
+
+
+def test_annotations():
+    @R.function
+    def f(
+        x: Tensor((32, m), "float32"),
+        y: Tensor((m, k), "float32"),
+        r: Tensor(_, "int64"),
+    ) -> Object:
+        z: Tensor((32, k), "float32") = nn.matmul(x, y, units=None)
+        w: Tensor(None, _) = multiply(z, z)
+        q: Tensor(None, _, ndim=2) = add(w, w)
+        t = subtract(w, z)
+        sh: Shape = t.shape
+        o: Object = relax.call_packed("contrib.tensor_array_stack", x, y, type_args=(Object))
+        return o
+
+    x, y, r = f.params
+    z_bind, w_bind, q_bind, t_bind, sh_bind, o_bind = f.body.blocks[0].bindings
+    z, mm = z_bind.var, z_bind.value
+    w, mul = w_bind.var, w_bind.value
+    q, add = q_bind.var, q_bind.value
+    t, sub = t_bind.var, t_bind.value
+    sh, shape_of = sh_bind.var, sh_bind.value
+    o, o_call_packed = o_bind.var, o_bind.value
+
+    check_tensor_var(x, (32, "m"), "float32")
+    check_tensor_var(y, ("m", "k"), "float32")
+    check_tensor_var(r, relax.RuntimeDepShape(), "int64")
+    check_tensor_var(z, (32, "k"), "float32")
+    check_tensor_var(w, None, "")
+    check_tensor_var(q, None, "", ndim=2)
+    assert t._checked_type_ is None
+    assert isinstance(sh._checked_type_, relax.ty.ShapeType)
+
+    check_call(mm, "nn.matmul", [x, y])
+    check_call(mul, "multiply", [z, z])
+    check_call(sub, "subtract", [w, z])
+    check_call(shape_of, "relax.shape_of", [t])
+
+    assert f.body.body == o
+
+    assert isinstance(f.ret_type, relax.ty.ObjectType)
+
+    assert isinstance(o._checked_type_, relax.ty.ObjectType)
+    assert len(o_call_packed.type_args) == 1
+
+
+def test_annotations_fail():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor("u", "int64")):
+            return x
+
+
+def test_mismatch_shape_dims_and_ndim():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor((2, 3), "float32", ndim=3)):
+            return x
+
+
+def test_unexpected_num_kw_args():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor(_, "float32", ndim=1, foo=2)):
+            return x
+
+
+def test_unexpected_kw_arg():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor(_, "float32", foo=1)):
+            return x
+
+
+def test_unexpected_ndim():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor(_, "float32", ndim=-2)):
+            return x
+
+
+def test_unexpected_ndim_type():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @R.function
+        def f(x: Tensor(_, "float32", ndim="1")):
+            return x
+
+
+def test_unexpected_tir_cast_args():
+    #
tir.cast expects 2 arguments, but got 3 + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor((m,), "float32")): + return relax.call_tir("foo", (x,), (tir.cast("int32", m, 1),), dtype="float32") + + +def test_unexpected_tir_max_args(): + # tir.max expects 2 arguments, but got 1 + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor((m, n), "float32")): + return relax.call_tir("foo", (x,), (tir.max(m),), dtype="float32") + + +def test_match_shape(): + @R.function + def f(x: Tensor(_, "float32")): + relax.match_shape(x.shape, (n, m)) + y: Tensor((n, m), "float32") = add(x, x) + return x + + match_sh = f.body.blocks[0].bindings[0] + pattern, value = match_sh.pattern, match_sh.value + + check_shape(pattern, ("n", "m")) + check_call(value, "relax.shape_of", [f.params[0]]) + + +def test_dim_var_intro_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor(_, _)): + y: Tensor((n, m), "float32") = x + return y + + +def test_if(): + @R.function + def f(cond: Tensor((), "bool"), x: Tensor((1,), "float32")): + if cond: + w = add(x, x) + y = multiply(w, w) + else: + w = multiply(x, x) + y = add(w, w) + return y + + cond, x = f.params + y_bind = f.body.blocks[0].bindings[0] + y, ite = y_bind.var, y_bind.value + + check_tensor_var(cond, tuple(), "bool") + check_tensor_var(x, (1,), "float32") + + assert isinstance(y, relax.Var) + assert y.name_hint == "y" + + assert isinstance(ite, relax.If) + assert isinstance(ite.true_branch, relax.SeqExpr) + assert isinstance(ite.false_branch, relax.SeqExpr) + + w_bind = ite.true_branch.blocks[0].bindings[0] + body = ite.true_branch.body + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "add", [x, x]) + check_call(body, "multiply", [w_bind.var, w_bind.var]) + + w_bind = ite.false_branch.blocks[0].bindings[0] + body = ite.false_branch.body + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "multiply", [x, x]) + check_call(body, "add", [w_bind.var, w_bind.var]) + + +# TODO: figure out if-else binding type and shape + + +def test_var_redefine_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x, y): + z = add(x, y) + y = z + return y + + +def test_var_redefine_fail_if(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: Tensor((), "bool"), x: Tensor((1,), "float32")): + y = x + if cond: + w = add(x, x) + y = multiply(w, w) + else: + w = multiply(x, x) + y = add(w, w) + return y + + +@pytest.mark.xfail +def test_var_if_scoping_fail(): + # TODO: fix this + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: Tensor((), "bool"), x: Tensor((1,), "float32")): + if cond: + w = add(x, x) + y = multiply(w, w) + else: + w = multiply(x, x) + y = add(w, w) + return w + + +def test_if_mismatch_var_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(cond: Tensor((), "bool"), x: Tensor((1,), "float32")): + if cond: + w = add(x, x) + y = multiply(w, w) + else: + w = multiply(x, x) + z = add(w, w) + return z + + +def test_unassigned_call_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor(_, _)): + add(x, x) + return x + + +def test_tuple(): + @R.function + def f(x: Tensor(_, _), y: Tensor((32,), "float32")): + t: Tuple(Tensor(_, _), Tensor((32,), "float32")) = (x, y) + return t + + x, y = f.params + t_bind = f.body.blocks[0].bindings[0] + t, tup = t_bind.var, t_bind.value + + annot = t._checked_type_ + assert 
isinstance(annot, relay.TupleType) + assert isinstance(annot.fields[0], relax.ty.DynTensorType) and annot.fields[0].dtype == "" + assert ( + isinstance(annot.fields[1], relax.ty.DynTensorType) and annot.fields[1].dtype == "float32" + ) + + assert t.shape_ is None + + assert isinstance(tup, relax.Tuple) + assert_structural_equal(tup.fields, [x, y]) + assert tup.shape_ is None + check_shape(tup.fields[0], relax.RuntimeDepShape()) + check_shape(tup.fields[1], (32,)) + + +def test_tuplegetitem(): + @R.function + def f(x: Tensor(_, _), y: Tensor(_, _)): + t1 = relax.Tuple((x, y)) + t2 = (x, y) + a = t1[0] + b = relax.TupleGetItem(t2, 1) + c = add(a, b) + return c + + x, y = f.params + bind_0 = f.body.blocks[0].bindings[0] + bind_1 = f.body.blocks[0].bindings[1] + bind_2 = f.body.blocks[0].bindings[2] + bind_3 = f.body.blocks[0].bindings[3] + bind_4 = f.body.blocks[0].bindings[4] + assert_structural_equal(bind_0.value.fields, [x, y]) + assert_structural_equal(bind_1.value.fields, [x, y]) + assert isinstance(bind_0.value, relax.expr.Tuple) + assert isinstance(bind_1.value, relax.expr.Tuple) + assert isinstance(bind_2.value, relax.TupleGetItem) + assert isinstance(bind_3.value, relax.TupleGetItem) + assert bind_2.value.index == 0 + assert bind_3.value.index == 1 + assert bind_2.var.name_hint == "a" + assert bind_3.var.name_hint == "b" + check_call(bind_4.value, "add", [bind_2.var, bind_3.var]) + + +def test_local_func(): + @R.function + def f(x: Tensor(_, _)): + @R.function + def bar(y: Tensor(_, _)): + return y + + y = bar(x) # tests local function variable scoping + return y + + bar_bind, y_bind = f.body.blocks[0].bindings + bar, bar_fn = bar_bind.var, bar_bind.value + bar_x = y_bind.value + + assert isinstance(bar_fn, relax.Function) + assert bar_fn.body.body == bar_fn.params[0] + + assert bar_x.op == bar + + +def test_dataflow(): + @R.function + def f(x: Tensor(_, _)): + with relax.dataflow(): + y = add(x, x) + z = multiply(y, x) + w = subtract(z, x) + relax.output(y, w) + t = divide(y, w) + return t + + assert len(f.body.blocks) == 2 + df_block = f.body.blocks[0] + y_bind, z_bind, w_bind = df_block.bindings + (t_bind,) = f.body.blocks[1].bindings + x = f.params[0] + y, z, w, t = map(lambda b: b.var, [y_bind, z_bind, w_bind, t_bind]) + + assert isinstance(y, relax.Var) + assert isinstance(z, relax.DataflowVar) + assert isinstance(w, relax.Var) + + check_call(y_bind.value, "add", [x, x]) + check_call(z_bind.value, "multiply", [y, x]) + check_call(w_bind.value, "subtract", [z, x]) + check_call(t_bind.value, "divide", [y, w]) + + assert f.body.body == t + + +def test_dataflow_match_shape(): + @R.function + def f(x: Tensor(_, _)): + with relax.dataflow(): + x2: Tensor((n, m), _) = relax.match_shape(x, (n, m)) + y = add(x2, x2) + z = multiply(y, x) + relax.match_shape(z.shape, (n, m)) + w: Tensor((n, m), _) = subtract(z, x) + relax.output(y, w, x2) + t: Tensor((n, m), _) = divide(y, w) + q: Tensor((n, m), _) = add(t, x2) + return q + + x = f.params[0] + df_block = f.body.blocks[0] + x2_bind = df_block.bindings[0] + z_shape_bind = df_block.bindings[3] + q_bind = f.body.blocks[1].bindings[1] + + assert x2_bind.var.name_hint == "x2" + check_tensor_var(x2_bind.var, ("n", "m"), "") + check_shape(x2_bind.pattern, ("n", "m")) + assert x2_bind.value == x + + check_shape(z_shape_bind.pattern, ("n", "m")) + + assert q_bind.value.args[1] == x2_bind.var + + +@pytest.mark.xfail +def test_dataflow_scope_fail(): + with pytest.raises(tvm.error.DiagnosticError): + # FIXME + @R.function + def f(x: Tensor(_, _)): + 
with relax.dataflow(): + y = add(x, x) + z = multiply(y, x) + w = subtract(z, x) + relax.output(y, w) + t = divide(y, z) + return t + + +def test_dataflow_syntax_fail_pattern(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor(_, _)): + with relax.dataflow() as df: + y = add(x, x) + z = multiply(y, x) + w = subtract(z, x) + relax.output(y, z) + t = divide(y, z) + return t + + +def test_dataflow_syntax_fail_params(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor(_, _)): + with relax.dataflow(x) as df: + y = add(x, x) + z = multiply(y, x) + w = subtract(z, x) + relax.output(y, w) + t = divide(y, z) + return t + + +def test_dataflow_unbound_outputs(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor(_, _)): + with relax.dataflow(): + y = add(x, x) + z = multiply(y, x) + w = subtract(z, x) + relax.output(x, y, w, q) + t = divide(y, z) + return t + + +def test_invalid_special_op_dataflow(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor): + y = add(x, x) + z = relax.dataflow() + return z + + +def test_invalid_special_op_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor): + y = add(x, x) + z = relax.output(y) + return z + + +def test_func_no_return_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor(_, _)): + y = add(x, x) + + +def test_call_tir(): + @R.function + def foo(x: Tensor((m, n), "float32")): + gv0 = relax.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + return gv0 + + call_tir_node = foo.body.blocks[0].bindings[0].value + assert call_tir_node.attrs is None + assert_structural_equal( + call_tir_node.type_args[0], relax.DynTensorType(ndim=2, dtype="float32") + ) + + +def test_inline_tir(): + @R.function + def f(x: Tensor((B, 128), "float32"), y: Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), (B, 128), dtype="float32") + return z + + x, y = f.params + B = x.shape_[0] + mm_bind, z_bind = f.body.blocks[0].bindings + + assert mm_bind.var.name_hint == "my_matmul" + assert isinstance(mm_bind.value, tir.PrimFunc) + + check_call( + z_bind.value, + "relax.call_tir", + [mm_bind.var, relax.Tuple([x, y]), relax.ShapeExpr([B, tir.IntImm("int64", 128)])], + ) + + +def test_call_packed(): + @R.function + def f(x: Tensor((3, 3), "float32")): + z: Tensor((n, m), "float32") = relax.call_packed( + "contrib.my_matmul", + x, + x, + mp=False, + type_args=(Tensor(ndim=2, dtype="float32")), + ) + + w = relax.call_packed( + "contrib.my_shape_of", + x, + dtype="int32", + attrs_type_key="relay.attrs.ShapeOfAttrs", + type_args=(Shape), + ) + + o = relax.call_packed("contrib.tensor_array_stack", x, z, type_args=(Object)) + + k = relax.call_packed( + "contrib.construct_tuple", + x, + x, + type_args=(Tuple(Tuple(Tensor(ndim=2, dtype="float32"), Tensor), Tensor)), + ) + return k + + x = f.params[0] + (z_bind, w_bind, o_bind, k_bind) = f.body.blocks[0].bindings + + z_var, z_value = z_bind.var, z_bind.value + check_tensor_var(z_var, ("n", "m"), "float32") + + assert isinstance(z_value.op, 
relax.ExternFunc) + assert z_value.op.global_symbol == "contrib.my_matmul" + assert "mp" in z_value.attrs and z_value.attrs["mp"] == False + assert_structural_equal(z_value.args, [x, x]) + assert len(z_value.type_args) == 1 + assert_structural_equal(z_value.type_args[0], relax.ty.DynTensorType(2, "float32")) + + w_value = w_bind.value + assert isinstance(w_value.attrs, relay.op.op_attrs.ShapeOfAttrs) + assert_structural_equal(w_value.type_args[0], relax.ty.ShapeType()) + + o_value = o_bind.value + assert_structural_equal(o_value.type_args[0], relax.ty.ObjectType()) + + k_value = k_bind.value + assert_structural_equal( + k_value.type_args[0], + relax.ty.TupleType( + [ + relax.TupleType( + [relax.ty.DynTensorType(2, "float32"), relax.ty.DynTensorType(-1, None)] + ), + relax.ty.DynTensorType(-1, None), + ] + ), + ) + + +def test_call_packed_no_type_args_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor((3, 3), "float32")): + z: Tensor((n, m), "float32") = relax.call_packed("contrib.my_matmul", x, x) + return z + + +def test_call_packed_wrong_type_args_fail(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def f(x: Tensor((3, 3), "float32")): + z: Tensor((n, m), "float32") = relax.call_packed( + "contrib.my_matmul", x, x, type_args=(Tuple) + ) + return z + + +def test_constant(): + @R.function + def f(x: Tensor((2, 3), "float32")): + y1 = relax.const(2, dtype="float32") + y2 = relax.const([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + z = add(x, y1) + r = add(z, y2) + return r + + x = f.params[0] + bind_0 = f.body.blocks[0].bindings[0] + assert bind_0.var.name_hint == "y1" + bind_1 = f.body.blocks[0].bindings[1] + assert bind_1.var.name_hint == "y2" + bind_2 = f.body.blocks[0].bindings[2] + assert bind_2.var.name_hint == "z" + bind_3 = f.body.blocks[0].bindings[3] + assert bind_3.var.name_hint == "r" + check_call(bind_2.value, "add", [x, bind_0.var]) + check_call(bind_3.value, "add", [bind_2.var, bind_1.var]) + + +def test_primexpr_arithmetic(): + @R.function + def f(x: Tensor((n, m), "float32")): + z: Tensor((n * m,), "float32") = relax.call_packed( + "my_flatten", (x,), type_args=(Tensor(ndim=2, dtype="float32")) + ) + sh: Shape = (n + m, n // m) + return z + + x = f.params[0] + n, m = x.shape_ + z_bind, sh_bind = f.body.blocks[0].bindings + + assert_structural_equal(z_bind.var.shape_.values, [tir.Mul(n, m)]) + assert_structural_equal(sh_bind.value.values, [tir.Add(n, m), tir.FloorDiv(n, m)]) + + +def test_call_tir_extern(): + @R.function + def f(x: Tensor) -> Tensor: + z = relax.call_tir("my_extern", (x,), (10,), dtype="float32") + return z + + x = f.params[0] + (z_bind,) = f.body.blocks[0].bindings + + check_call( + z_bind.value, + "relax.call_tir", + [ + relax.ExternFunc("my_extern"), + relax.Tuple([x]), + relax.ShapeExpr([tir.IntImm("int64", 10)]), + ], + ) + + +def test_empty_shape(): + @R.function + def f(x: Tensor((), "float32"), y: Tensor((), "float32")): + @T.prim_func + def scalar_add(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, ()) + B = T.match_buffer(b, ()) + C = T.match_buffer(c, ()) + + with T.block("add"): + C[()] = A[()] + B[()] + + z = relax.call_tir(scalar_add, (x, y), (), dtype="float32") + return z + + x, y = f.params + add_bind, z_bind = f.body.blocks[0].bindings + + assert add_bind.var.name_hint == "scalar_add" + assert isinstance(add_bind.value, tir.PrimFunc) + + check_call( + z_bind.value, + "relax.call_tir", + [add_bind.var, relax.Tuple([x, y]), relax.ShapeExpr([])], + ) + + +def 
test_class_irmodule():
+    @tvm.script.ir_module
+    class MyModule:
+        @T.prim_func
+        def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+            A = T.match_buffer(a, (128, 128))
+            B = T.match_buffer(b, (128, 128))
+            C = T.match_buffer(c, (128, 128))
+
+            for i, j, k in T.grid(128, 128, 128):
+                with T.block():
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = 0.0
+                    C[vi, vj] += A[vi, vk] * B[vj, vk]
+
+        @R.function
+        def f(x: Tensor((n, n), _)) -> Tensor:
+            return g(x)
+
+        @R.function
+        def g(y: Tensor((n, n), _)) -> Tensor:
+            return relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32")
+
+        @R.function
+        def j(y: Tensor((n, n), _)) -> Tensor:
+            with relax.dataflow():
+                gv = relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32")
+                gv1 = (gv, gv)
+                gv2 = gv1[1]
+                relax.output(gv2)
+            return gv2
+
+        @R.function
+        def h(x: Tensor((n, n), _), y: Tensor((n, n), _), z: Tensor((n, n), _)) -> Tensor:
+            _ = my_matmul(x, y, z)
+            return z
+
+        @R.function
+        def k(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
+            gv0 = relax.call_packed(
+                "test.vm.mul", x, w, type_args=(Tensor(ndim=2, dtype="float32"))
+            )
+            return gv0
+
+    my_module = MyModule
+    assert isinstance(my_module, tvm.IRModule)
+
+    R.parser.pretty_print(my_module)
+    # check that we can print TIR and Relax functions too using the same api.
+    R.parser.pretty_print(my_module["my_matmul"])
+    R.parser.pretty_print(my_module["f"])
+
+    var_f = my_module.get_global_var("f")
+    var_g = my_module.get_global_var("g")
+    var_j = my_module.get_global_var("j")
+    var_k = my_module.get_global_var("k")
+    var_my_matmul = my_module.get_global_var("my_matmul")
+    f = my_module[var_f]
+    g = my_module[var_g]
+    j = my_module[var_j]
+    k = my_module[var_k]
+
+    assert f.body.op == var_g
+    assert g.body.args[0] == var_my_matmul
+
+    gv_bind = j.body.blocks[0].bindings[0]
+    assert gv_bind.value.checked_type.ndim == 2
+    assert gv_bind.value.checked_type.dtype == "float32"
+    assert gv_bind.var.checked_type.ndim == 2
+    assert gv_bind.var.checked_type.dtype == "float32"
+    check_shape(gv_bind.value, ("n", "n"))
+    check_shape(gv_bind.var, ("n", "n"))
+
+    # check call_packed checked_type_
+    gv0_bind = k.body.blocks[0].bindings[0]
+    assert gv0_bind.value.checked_type.dtype == "float32"
+    assert gv0_bind.value.checked_type.ndim == 2
+    assert gv0_bind.var.checked_type.dtype == "float32"
+    assert gv0_bind.var.checked_type.ndim == 2
+
+    # check function type
+    j_type = j.checked_type
+    assert isinstance(j_type, relax.FuncType)
+    assert isinstance(j_type.ret_type, relax.DynTensorType)
+    assert j_type.ret_type.ndim == 2
+    assert j_type.ret_type.dtype == "float32"
+    assert len(j_type.arg_types) == 1
+    assert isinstance(j_type.arg_types[0], relax.DynTensorType)
+    assert j_type.arg_types[0].ndim == 2
+
+    # check SeqExpr type/shape
+    assert isinstance(j.body, relax.SeqExpr)
+    assert j.body.checked_type.dtype == "float32"
+    assert j.body.checked_type.ndim == 2
+    check_shape(j.body, ("n", "n"))
+
+    # check tuple type/shape
+    gv1_bind = j.body.blocks[0].bindings[1]
+    assert isinstance(gv1_bind.value, relax.Tuple)
+    assert isinstance(gv1_bind.value.checked_type, relax.TupleType)
+    assert isinstance(gv1_bind.var.checked_type, relax.TupleType)
+    assert gv1_bind.var.checked_type.fields[0].ndim == 2
+    assert gv1_bind.var.checked_type.fields[0].dtype == "float32"
+    assert isinstance(gv1_bind.var.shape, relax.Tuple)
+    assert isinstance(gv1_bind.value.shape, relax.Tuple)
+    check_shape(gv1_bind.value.shape.fields[0], ("n", "n"))
+    check_shape(gv1_bind.value.shape.fields[1], ("n", "n"))
+    check_shape(gv1_bind.var.shape.fields[0], ("n", "n"))
+    check_shape(gv1_bind.var.shape.fields[1], ("n", "n"))
+
+    # check TupleGetItem type/shape
+    gv2_bind = j.body.blocks[0].bindings[2]
+    assert isinstance(gv2_bind.value, relax.TupleGetItem)
+    assert gv2_bind.value.checked_type.ndim == 2
+    assert gv2_bind.value.checked_type.dtype == "float32"
+    assert gv2_bind.var.checked_type.ndim == 2
+    assert gv2_bind.var.checked_type.dtype == "float32"
+    check_shape(gv2_bind.value.shape, ("n", "n"))
+    check_shape(gv2_bind.var, ("n", "n"))
+
+
+def test_class_normalize():
+    @tvm.script.ir_module
+    class InputModule:
+        @R.function
+        def mul_add(x: Tensor) -> Tensor:
+            return R.multiply(R.add(x, x), R.add(x, x))
+
+    # The parser automatically normalizes the input AST to the following ANF form
+    @tvm.script.ir_module
+    class OutputModule:
+        @R.function
+        def mul_add(x: Tensor) -> Tensor:
+            gv = relax.add(x, x)
+            gv1 = relax.add(x, x)
+            return R.multiply(gv, gv1)
+
+    assert_structural_equal(InputModule, OutputModule)
+"""Unit tests for relax pass manager.""" +from __future__ import annotations # must import to defer parsing of annotations +import numpy as np +import pytest +import tvm +from tvm import relax, ir +from tvm.ir.base import assert_structural_equal +from tvm.relax.expr import Call + +import tvm.script +from tvm.script import tir as T, relax as R + + +def check_equal(mod1, mod2): + mod1 = relax.transform.Normalize()(mod1) + mod2 = relax.transform.Normalize()(mod2) + assert_structural_equal(mod1, mod2) + + +def test_function_class_pass(): + @relax.transform.function_pass(opt_level=1) + class TestReplaceFunc: + """Simple test function to replace one argument to another.""" + + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + return self.new_func + + @tvm.script.ir_module + class Before: + @R.function + def f1(x: Tensor((m, n), "float32")): + return x + + @tvm.script.ir_module + class Expected: + @R.function + def f2(x: Tensor((m, n), "float32")): + gv0 = relax.add(x, x) + return gv0 + + fpass = TestReplaceFunc(Expected["f2"]) + assert fpass.info.opt_level == 1 + assert fpass.info.name == "TestReplaceFunc" + After = fpass(Before) + assert_structural_equal(After["f1"], Expected["f2"]) + + +# Swap Multiply and Add Ops +@relax.expr_functor.mutator +class SwapMAVar(relax.PyExprMutator): + def __init__(self) -> None: + super().__init__() + + def visit_call_(self, call: Call) -> Call: + call = self.visit_expr_post_order(call) + if call.op == ir.Op.get("relax.add"): + new_op = ir.Op.get("relax.multiply") + elif call.op == ir.Op.get("relax.multiply"): + new_op = ir.Op.get("relax.add") + else: + new_op = self.visit_expr(call.op) + + new_call = Call(new_op, call.args, call.attrs, call.type_args, call.span) + return self.builder_.normalize(new_call) + + +def test_function_pass(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.add(lv0, y) + relax.output(gv0) + gv1 = relax.multiply(x, y) + gv2 = relax.add(gv1, y) + return (gv0, gv1, gv2) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.add(x, y) + gv0 = relax.multiply(lv0, y) + relax.output(gv0) + gv1 = relax.add(x, y) + gv2 = relax.multiply(gv1, y) + return (gv0, gv1, gv2) + + pass_name = "function_pass_test" + opt_level = 0 + + # create FunctionPass with the function_pass decorator + @relax.transform.function_pass(opt_level=opt_level, name=pass_name) + def decorator_transform(func, mod, ctx): + return SwapMAVar().visit_expr(func) + + # check the transform info + assert isinstance(decorator_transform, relax.transform.FunctionPass) + assert decorator_transform.info.name == pass_name + assert decorator_transform.info.opt_level == opt_level + # run the transform + After = decorator_transform(Before) + check_equal(After, Expected) + + # create FunctionPass directly with the function_pass init call + def direct_transform(func, mod, ctx): + return SwapMAVar().visit_expr(func) + + direct_transform = relax.transform.function_pass(direct_transform, opt_level=opt_level) + assert isinstance(direct_transform, relax.transform.FunctionPass) + assert direct_transform.info.name == "direct_transform" + assert direct_transform.info.opt_level == opt_level + # run the transform + After = direct_transform(Before) + check_equal(After, Expected) + + +def 
test_dataflowblock_class_pass():
+    @relax.transform.dataflowblock_pass(opt_level=1)
+    class TestReplaceBinding:
+        """Simple test pass that replaces the first VarBinding with another."""
+
+        def __init__(self):
+            # create a new VarBinding
+            type_anno = relax.DynTensorType(2, "float32")
+            lv0 = relax.Var("lv1", (2, 2), type_anno)
+            val = relax.const(np.random.rand(24, 56))
+            self.new_binding = relax.VarBinding(lv0, val)
+
+        def transform_dataflowblock(self, block, mod, ctx):
+            bindings = block.bindings
+            new_bindings = [self.new_binding, bindings[1]]
+            new_block = relax.expr.DataflowBlock(new_bindings, block.span)
+            return new_block
+
+    @tvm.script.ir_module
+    class Mod1:
+        @R.function
+        def f(x: Tensor((m, n), "float32")):
+            with relax.dataflow():
+                lv0 = relax.multiply(x, x)
+                gv0 = relax.add(x, x)
+                relax.output(gv0)
+            return gv0
+
+    @tvm.script.ir_module
+    class Mod2:
+        @R.function
+        def f(x: Tensor((m, n), "float32")):
+            with relax.dataflow():
+                lv0 = relax.add(x, x)
+                gv0 = relax.add(x, x)
+                relax.output(gv0)
+            return gv0
+
+    block_pass = TestReplaceBinding()
+    assert block_pass.info.opt_level == 1
+    assert block_pass.info.name == "TestReplaceBinding"
+    updated_mod1 = block_pass(Mod1)
+    updated_mod2 = block_pass(Mod2)
+    assert_structural_equal(updated_mod1["f"], updated_mod2["f"])
+
+
+def test_dataflowblock_pass():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")):
+            with relax.dataflow():
+                lv0 = relax.multiply(x, y)
+                gv0 = relax.add(lv0, y)
+                relax.output(gv0)
+            gv1 = relax.multiply(x, y)
+            gv2 = relax.add(gv1, y)
+            return (gv0, gv1, gv2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")):
+            with relax.dataflow():
+                lv0 = relax.add(x, y)
+                gv0 = relax.multiply(lv0, y)
+                relax.output(gv0)
+            gv1 = relax.multiply(x, y)
+            gv2 = relax.add(gv1, y)
+            return (gv0, gv1, gv2)
+
+    pass_name = "dataflow_pass_test"
+    opt_level = 0
+
+    # create DataflowBlockPass with the dataflowblock_pass decorator
+    @relax.transform.dataflowblock_pass(opt_level=opt_level, name=pass_name)
+    def decorator_transform(block, mod, ctx):
+        return SwapMAVar().visit_binding_block(block)
+
+    # check the transform info
+    assert isinstance(decorator_transform, relax.transform.DataflowBlockPass)
+    assert decorator_transform.info.name == pass_name
+    assert decorator_transform.info.opt_level == opt_level
+    # run the transform
+    After = decorator_transform(Before)
+    check_equal(After, Expected)
+
+    # create DataflowBlockPass directly with the dataflowblock_pass init call
+    def direct_transform(block, mod, ctx):
+        return SwapMAVar().visit_binding_block(block)
+
+    direct_transform = relax.transform.dataflowblock_pass(direct_transform, opt_level=opt_level)
+    assert isinstance(direct_transform, relax.transform.DataflowBlockPass)
+    assert direct_transform.info.name == "direct_transform"
+    assert direct_transform.info.opt_level == opt_level
+    # run the transform
+    After = direct_transform(Before)
+    check_equal(After, Expected)
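+
+
+# A composition sketch (an addition for illustration, not part of the original
+# tests): passes built with either decorator are ordinary passes, so they can
+# be chained with tvm.transform.Sequential like the other test files do.
+def test_passes_compose_in_sequential():
+    @relax.transform.function_pass(opt_level=0, name="identity_fn")
+    def identity_fn(func, mod, ctx):
+        return func
+
+    @relax.transform.dataflowblock_pass(opt_level=0, name="identity_block")
+    def identity_block(block, mod, ctx):
+        return block
+
+    seq = tvm.transform.Sequential([identity_fn, identity_block])
+    assert isinstance(seq, tvm.transform.Pass)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py
new file mode 100644
index 0000000000..2dd2a91846
--- /dev/null
+++ b/tests/python/relax/test_printer.py
@@ -0,0 +1,410 @@
+# Licensed to the Apache Software Foundation (ASF) under one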
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations  # must import to defer parsing of annotations
+
+import pytest
+import tvm
+
+from tvm import relax
+from tvm import tir, relay
+from tvm.ir import structural_equal, assert_structural_equal
+
+import tvm.script
+from tvm.relax.utils import metadata_partitioner
+from tvm.script import tir as T, relax as R
+
+
+def check_roundtrip(f_pre):
+    relax_text = R.parser.astext(f_pre, show_meta_data=True)
+    f_post = R.parser.from_source(input_func=relax_text)
+    if isinstance(f_pre, tvm.IRModule) and not isinstance(f_post, tvm.IRModule):
+        global_vars = f_pre.get_global_vars()
+        # reuse the attrs of the input module so the structural check below
+        # compares matching metadata
+        f_post = tvm.IRModule({global_vars[0]: f_post}, attrs=f_pre.attrs)
+    assert_structural_equal(f_pre, f_post, map_free_vars=True)
+
+
+def test_annotations():
+    @R.function
+    def foo(x: Tensor((32, m), "float32"), y: Tensor((m, k), "float32")) -> Tensor:
+        z: Tensor((32, k), "float32") = nn.matmul(x, y, units=None)
+        w: Tensor(_, _) = multiply(z, z)
+        t = subtract(w, z)
+        sh: Shape = t.shape
+        return t
+
+    check_roundtrip(foo)
+
+
+def test_ndim_annotations():
+    @R.function
+    def foo(
+        x: Tensor((2, 3, 5), "float32", ndim=3),
+        y: Tensor(_, "float32", ndim=-1),
+        z: Tensor(_, "float32", ndim=2),
+    ):
+        w: Tensor(None, "float32", ndim=-1) = x + x
+        return w
+
+    check_roundtrip(foo)
+
+
+def test_match_shape():
+    @R.function
+    def foo(x: Tensor(_, "float32")):
+        relax.match_shape(x.shape, (n, m))
+        y: Tensor((n, m), "float32") = add(x, x)
+        return x
+
+    check_roundtrip(foo)
+
+
+def test_if():
+    @R.function
+    def foo(cond: Tensor((), "bool"), x: Tensor((1,), "float32")):
+        if cond:
+            w = add(x, x)
+            y = multiply(w, w)
+        else:
+            w = multiply(x, x)
+            y = add(w, w)
+        return y
+
+    check_roundtrip(foo)
+
+
+def test_tuple():
+    @R.function
+    def foo(x: Tensor(_, _), y: Tensor((32,), "float32")):
+        t: Tuple(Tensor(_, _), Tensor((32,), "float32")) = (x, y)
+        return t
+
+    check_roundtrip(foo)
+
+
+def test_tuplegetitem():
+    @R.function
+    def foo(x: Tensor(_, _)):
+        y = add(x, x)
+        z = multiply(y, x)
+        t = relax.Tuple((y, z))
+        a = relax.TupleGetItem(t, 0)
+        b = relax.TupleGetItem(t, 1)
+        c = divide(a, b)
+        return c
+
+    check_roundtrip(foo)
+
+
+def test_local_func():
+    @R.function
+    def foo(x: Tensor(_, _)):
+        @R.function
+        def bar(y: Tensor(_, _)):
+            return y
+
+        y = bar(x)  # tests local function variable scoping
+        return y
+
+    check_roundtrip(foo)
+
+
+def test_dataflow():
+    @R.function
+    def foo(x: Tensor(_, _)):
+        with relax.dataflow():
+            # TODO: parse this
+            # nonlocal y, w
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            relax.output(y, w)
+        t = divide(y, w)
+        return t
+
+    check_roundtrip(foo)
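+
+
+# A printing-stability sketch (an addition for illustration, not part of the
+# original tests): printing, re-parsing, and printing again should yield the
+# same text for a trivial function.
+def test_print_parse_print_stable():
+    @R.function
+    def foo(x: Tensor((3, 3), "float32")):
+        return x
+
+    first = R.parser.astext(foo)
+    second = R.parser.astext(R.parser.from_source(input_func=first))
+    assert first == second
+
+
+def test_dataflow_match_shape():
+    @R.function
+    def foo(x: Tensor(_, _)):
+        with relax.dataflow():
+            x2: Tensor((n, m), _) = relax.match_shape(x, (n, m))
+            y = add(x2, x2)
+            z = multiply(y, x)
+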
relax.match_shape(z.shape, (n, m)) + w: Tensor((n, m), _) = subtract(z, x) + relax.output(y, w, x2) + t: Tensor((n, m), _) = divide(y, w) + q: Tensor((n, m), _) = add(t, x2) + return q + + check_roundtrip(foo) + + +def test_inline_tir(): + @R.function + def foo(x: Tensor((B, 128), "float32"), y: Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), (B, 128), dtype="float32") + return z + + check_roundtrip(foo) + + +def test_call_packed(): + @R.function + def foo(x: Tensor((3, 3), "float32")): + # test that we can intro dim vars + z: Tensor((n, m), "float32") = relax.call_packed( + "contrib.my_matmul", x, x, mp=False, type_args=(Tensor(ndim=2, dtype="float32")) + ) + w = relax.call_packed( + "contrib.my_shape_of", + x, + dtype="int32", + attrs_type_key="relay.attrs.ShapeOfAttrs", + type_args=(Shape), + ) + o = relax.call_packed("contrib.tensor_array_stack", x, z, type_args=(Object)) + return z + + check_roundtrip(foo) + + +def test_primexpr_arithmetic(): + @R.function + def foo(x: Tensor((n, m), "float32")): + z: Tensor((n * m,), "float32") = relax.call_packed( + "my_flatten", (x,), type_args=(Tensor(ndim=2, dtype="float32")) + ) + sh: Shape = (n + m, n // m) + return z + + check_roundtrip(foo) + + +def test_call_tir_extern(): + @R.function + def foo(x: Tensor): + z = relax.call_tir("my_extern", (x,), (10,), dtype="float32") + return z + + check_roundtrip(foo) + + +def test_const_irmodule(): + def _gen_meta_data(): + @tvm.script.ir_module + class Module: + @R.function + def my_const(x: Tensor((2, 3), "float32")): + y: Tensor((2, 3), "float32") = relax.const( + [[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]], dtype="float32" + ) + z: Tensor((2, 3), "float32") = relax.add(x, y) + return z + + mod = Module + relax_text = R.parser.astext(mod, show_meta_data=True) + texts = metadata_partitioner(relax_text) + return texts[1] + + json_str = _gen_meta_data() + + @tvm.script.ir_module(metadata=json_str) + class MyModule: + @R.function + def my_const(x: Tensor((2, 3), "float32")): + z: Tensor((2, 3), "float32") = relax.add(x, meta[relay.Constant][0]) + return z + + my_module = MyModule + + check_roundtrip(my_module) + + +def test_const(): + @R.function + def my_const(x: Tensor((2, 3), "float32")): + y1 = relax.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]) + y2 = relax.const(2.1, dtype="float32") + y3 = relax.const([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]) + z = relax.add(x, y1) + r = relax.add(z, y2) + w = relax.add(r, y3) + return w + + check_roundtrip(my_const) + + +def test_const_meta(): + def _get_meta_data(): + @R.function + def my_const(x: Tensor((2, 3), "float32")): + y1: Tensor((2, 3), "float32") = relax.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]) + y2 = relax.const(2.1, dtype="float32") + y3: Tensor((2, 3), "float32") = relax.const([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]) + z: Tensor((2, 3), "float32") = relax.add(x, y1) + r: Tensor((2, 3), "float32") = relax.add(z, y2) + w: Tensor((2, 3), "float32") = relax.add(r, y3) + return w + + relax_text = R.parser.astext(my_const, show_meta_data=True) + texts = metadata_partitioner(relax_text) + return texts[1] + + json_str = _get_meta_data() + + @R.function(metadata=json_str) + def my_const(x: 
Tensor((2, 3), "float32")): + y2 = relax.const(2.1, dtype="float32") + z: Tensor((2, 3), "float32") = relax.add(x, meta[relay.Constant][0]) + r: Tensor((2, 3), "float32") = relax.add(z, y2) + w: Tensor((2, 3), "float32") = relax.add(r, meta[relay.Constant][1]) + return w + + check_roundtrip(my_const) + + +def test_class_irmodule(): + @tvm.script.ir_module + class MyModule: + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + @R.function + def f(x: Tensor((n, n), _)) -> Tensor: + # todo(@yongwww): Update the check_type_ function's body is a call_node + r = g(x) + return r + + @R.function + def g(y: Tensor((n, n), _)) -> Tensor: + r = relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32") + return r + + @R.function + def h(x: Tensor((n, n), _), y: Tensor((n, n), _), z: Tensor((n, n), _)) -> Tensor: + _ = my_matmul(x, y, z) + return z + + my_module = MyModule + check_roundtrip(my_module) + + +def test_tir_max(): + @R.function + def tir_max(x: Tensor((m, n), "float32")): + gv = relax.call_tir("my_extern", (x,), (tir.max(n, m),), dtype="float32") + return gv + + check_roundtrip(tir_max) + + +def test_tir_cast(): + @R.function + def tir_cast(x: Tensor((m,), "float32")): + gv = relax.call_tir("my_extern", (x,), (tir.cast("int32", m),), dtype="float32") + return gv + + check_roundtrip(tir_cast) + + +def test_dyntensor_type(): + x = relax.DynTensorType(ndim=3, dtype="float32") + assert x.__str__() == 'Tensor[ndim=3, dtype="float32"]' + + +def test_object_type(): + x = relax.ObjectType() + assert x.__str__() == "Object" + + +def test_shape_expr(): + x = relax.ShapeExpr([tir.IntImm("int64", 10), tir.IntImm("int64", 5)]) + assert x.__str__() == "(10, 5)" + + +def test_runtime_dep_shape(): + x = relax.RuntimeDepShape() + assert x.__str__() == "_" + + +def test_func_type(): + @tvm.script.ir_module + class TestFuncType: + @R.function + def global_func_1( + x: Tensor((m, n), "float32") + ) -> Callable((Tensor((m, n), "float32")), Tensor((m, n), "float32")): + @R.function + def local_func_1(y: Tensor((m, n), "float32")) -> Tensor((m, n), "float32"): + s = relax.add(x, y) + return s + + return local_func_1 + + @R.function + def global_func_2( + x: Tensor((m, n), "float32") + ) -> Callable( + (Tensor(None, "float32", ndim=2)), + Callable((Tensor((m, n), "float32"),), Tensor((m, n), "float32")), + ): + @R.function + def local_func_1( + y: Tensor((m, n), "float32") + ) -> Callable((Tensor((m, n), "float32"),), Tensor((m, n), "float32")): + @R.function + def local_func_2(z: Tensor((m, n), "float32")) -> Tensor(None, "float32", ndim=2): + s1 = relax.add(x, y) + s2 = relax.add(z, s1) + return s2 + + return local_func_2 + + return local_func_1 + + func_type = TestFuncType + check_roundtrip(func_type) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py new file mode 100644 index 0000000000..acb560d9fc --- /dev/null +++ b/tests/python/relax/test_relax_operators.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations # must import to defer parsing of annotations +import sys +import tempfile +import pytest +import tvm +from tvm import relax +from tvm._ffi.base import TVMError + +from tvm.script import relax as R + +import numpy as np + + +@tvm.script.ir_module +class InputModule: + @R.function + def foo(x: Tensor((m, n), "int64")): + y = relax.unique(x, sorted=False) + y_sorted = relax.unique(x) + return y, y_sorted + + +def run_cpu(mod, func_name, *input): + target = tvm.target.Target("llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + return vm[func_name](*input) + + +def test_unique(): + + # TODO(prakalp): also add test for compiling and running on cuda device. + data_numpy = np.random.randint(0, 16, (16, 16)) + data = tvm.nd.array(data_numpy) + result, result_sorted = run_cpu(InputModule, "foo", data) + + expected_output_sorted, indices = np.unique(data_numpy, return_index=True) + expected_output = [data_numpy.flatten()[index] for index in sorted(indices, reverse=True)] + + np.testing.assert_array_equal(expected_output_sorted, result_sorted.numpy()) + np.testing.assert_array_equal(expected_output, result.numpy()) + + +@tvm.script.ir_module +class PrintTest: + @R.function + def foo(x: Tensor((), "int32")): + # results have to be bound, but we don't use them + # TODO: We should allow calls whose results are not bound for side effects; + # it would be easy syntactic sugar to add. 
+ p1 = relax.print(x) + p2 = relax.print(x, format="Number: {}") + t = (x, x) + p3 = relax.print(t, format="Tuple: {}") + p4 = relax.print(x, t) + p5 = relax.print(x, x, format="Custom print: {} {}") + p6 = relax.print(x, t, format="Another print: {} {}") + return x + + +def test_print(): + try: + stdout = sys.stdout + with tempfile.TemporaryFile(mode="w+") as test_out: + sys.stdout = test_out + run_cpu(PrintTest, "foo", tvm.nd.array(1)) + test_out.seek(0) + printed_text = str(test_out.read()) + expected = "1\nNumber: 1\nTuple: (1, 1)\n1 (1, 1)\nCustom print: 1 1\nAnother print: 1 (1, 1)\n" + assert printed_text == expected + finally: + sys.stdout = stdout + + +@tvm.script.ir_module +class AssertOpTest: + @R.function + def passes(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(True)) + return x + + @R.function + def pass_with_args(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(True), x, format="You won't see me") + return x + + @R.function + def simple_fail(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(False)) + return x + + @R.function + def fail_with_message(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(False), format="I failed...") + return x + + @R.function + def fail_with_args(x: Tensor((), "int32")): + # no format + p1 = relax.assert_op(relax.const(False), x, x) + return x + + @R.function + def fail_with_formatted_message(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(False), x, format="Number: {}") + return x + + +def test_assert_op(): + def check_assertion_error(func_name, func_arg, expected_message): + passed = False + try: + run_cpu(AssertOpTest, func_name, func_arg) + passed = True + except TVMError as e: + # TVM will print out a TVMError that will contain the + # generated error at the bottom of a stack trace + assert "AssertionError" in e.args[0] + assert expected_message in e.args[0] + assert not passed + + run_cpu(AssertOpTest, "passes", tvm.nd.array(1)) + run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(2)) + check_assertion_error("simple_fail", tvm.nd.array(3), "Assertion Failed") + check_assertion_error("fail_with_message", tvm.nd.array(4), "I failed...") + check_assertion_error("fail_with_args", tvm.nd.array(5), "5, 5") + check_assertion_error("fail_with_formatted_message", tvm.nd.array(6), "Number: 6") + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_relax_usmp_analysis_assign_poolinfo.py b/tests/python/relax/test_relax_usmp_analysis_assign_poolinfo.py new file mode 100644 index 0000000000..80a1a501a9 --- /dev/null +++ b/tests/python/relax/test_relax_usmp_analysis_assign_poolinfo.py @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations # must import to defer parsing of annotations + +import sys +import pytest +import tvm +from tvm.ir import Span +from tvm.ir.memory_pools import WorkspaceMemoryPools +from tvm.relax import expr_functor, PyExprVisitor, PyExprMutator, Expr +from tvm.relax.testing import dump_ast +import tvm.script +import tvm.testing +from tvm import relax, rpc, te, tir, topi, TVMError, cpu, WorkspacePoolInfo, ConstantPoolInfo +from tvm.script import relax as R, tir as T +from tvm.target import Target + + +def _check_pool_infos(mod, target_to_pool_info): + @relax.expr_functor.visitor + class RelaxFuncCheck(PyExprVisitor): + def visit_span(self, span: Span): + pass + + def visit_call_(self, op: tvm.relax.Call): + call = op + if "relax.builtin.alloc_tensor" == str(call.op): + candidate_pools = call.attrs["candidate_memory_pools"] + assert candidate_pools[0] == pool_info + return super().visit_call_(op) + + def check_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + assert stmt.annotations["candidate_memory_pools"][0] == pool_info + return stmt + + for global_var, basefunc in mod.functions.items(): + pool_info = target_to_pool_info[basefunc.attrs["target"]] + if isinstance(basefunc, tvm.relax.Function): + RelaxFuncCheck().visit_expr(basefunc) + if isinstance(basefunc, tvm.tir.PrimFunc): + basefunc.with_body( + tvm.tir.stmt_functor.ir_transform(basefunc.body, None, check_poolinfos) + ) + + +device = cpu(0) + +# fmt: off +@tvm.script.ir_module +class RelaxAndTIR: + @T.prim_func + def prim_func_1(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "prim_func_1", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + Conv2dOutput_7[ff_3] = 0 + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) + for ax3_inner_7 in T.serial(0, 64): + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") + + @R.function + def __tvm_main__(input: Tensor((16, 16), "uint8")) -> Tensor: + tsid_11 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_12 = 
relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + + output = relax.call_tir("prim_func_1", (input, tsid_11, tsid_12), (802816, 1), dtype="int32") + return output +# fmt: on + + +def test_relax_and_tir(): + target = Target("c") + target_llvm = Target("llvm") + relax_mod = RelaxAndTIR + passes = [relax.transform.ToNonDataflow(), relax.transform.CallTIRRewrite()] + seq = tvm.transform.Sequential(passes) + relax_mod = seq(relax_mod) + + relax_mod["__tvm_main__"] = relax_mod["__tvm_main__"].with_attr("target", target) + relax_mod["prim_func_1"] = relax_mod["prim_func_1"].with_attr("target", target_llvm) + + c_target_pool = WorkspacePoolInfo(pool_name="c_target_pool", targets=[target]) + llvm_target_pool = WorkspacePoolInfo(pool_name="llvm_target_pool", targets=[target_llvm]) + workspace_memory_pools = WorkspaceMemoryPools([c_target_pool, llvm_target_pool]) + relax_mod = relax_mod.with_attr("workspace_memory_pools", workspace_memory_pools) + + relax_mod = tvm.relax.transform.AssignPoolInfo()(relax_mod) + _check_pool_infos(relax_mod, {target: c_target_pool, target_llvm: llvm_target_pool}) + + +# fmt: off +@tvm.script.ir_module +class RelaxAndTIRMultipleTargets: + @T.prim_func + def prim_func_1(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "prim_func_1", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + Conv2dOutput_7[ff_3] = 0 + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) + for ax3_inner_7 in T.serial(0, 64): + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def prim_func_2(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for 
ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") + + @R.function + def __tvm_main__(input: Tensor((16, 16), "uint8")) -> Tensor: + tsid_11 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_12 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + + lv1 = relax.call_tir("prim_func_1", (input, tsid_11, tsid_12), (802816, 1), dtype="int32") + output = relax.call_tir("prim_func_2", (lv1), (802816, 1), dtype="int32") + return output +# fmt: on + + +def test_relax_and_tir_multiple_targets(): + target = Target("c") + target_llvm = Target("llvm") + target_cuda = Target("cuda") + relax_mod = RelaxAndTIRMultipleTargets + passes = [relax.transform.ToNonDataflow(), relax.transform.CallTIRRewrite()] + seq = tvm.transform.Sequential(passes) + relax_mod = seq(relax_mod) + + relax_mod["__tvm_main__"] = relax_mod["__tvm_main__"].with_attr("target", target) + relax_mod["prim_func_1"] = relax_mod["prim_func_1"].with_attr("target", target_llvm) + relax_mod["prim_func_2"] = relax_mod["prim_func_2"].with_attr("target", target_cuda) + + c_target_pool = WorkspacePoolInfo(pool_name="c_target_pool", targets=[target]) + llvm_target_pool = WorkspacePoolInfo(pool_name="llvm_target_pool", targets=[target_llvm]) + cuda_target_pool = WorkspacePoolInfo(pool_name="cuda_target_pool", targets=[target_cuda]) + workspace_memory_pools = WorkspaceMemoryPools( + [c_target_pool, llvm_target_pool, cuda_target_pool] + ) + relax_mod = relax_mod.with_attr("workspace_memory_pools", workspace_memory_pools) + + relax_mod = tvm.relax.transform.AssignPoolInfo()(relax_mod) + _check_pool_infos( + relax_mod, + {target: c_target_pool, target_llvm: llvm_target_pool, target_cuda: cuda_target_pool}, + ) + + +if __name__ == "__main__": + pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/relax/test_relax_usmp_analysis_extract_bufferinfo.py b/tests/python/relax/test_relax_usmp_analysis_extract_bufferinfo.py new file mode 100644 index 0000000000..5ea8a0b2df --- /dev/null +++ b/tests/python/relax/test_relax_usmp_analysis_extract_bufferinfo.py @@ -0,0 +1,1996 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations # must import to defer parsing of annotations + +import sys +import pytest +import tvm +from tvm.ir import Span +from tvm.relax import expr_functor, PyExprVisitor, PyExprMutator, Expr +import tvm.script +import tvm.testing +from tvm import relax, rpc, te, tir, topi, TVMError, cpu, WorkspacePoolInfo, ConstantPoolInfo +from tvm.script import relax as R, tir as T +from tvm.target import Target + + +def _replace_stmt_with_buf_var_names(buffer_info_map): + """helper to key the buffer info map by buffer names""" + new_buffer_info_map = dict() + for k, v in buffer_info_map.items(): + new_buffer_info_map[k.name_hint] = k + return new_buffer_info_map + + +def _verify_conflicts(main_buf_name, conflicting_buf_names, buffer_info_map): + """helper to check expected liveness conflicts""" + buf_info = buffer_info_map[main_buf_name] + for conflict in buf_info.conflicts: + assert conflict.name_hint in conflicting_buf_names + + +def _assign_poolinfos_to_allocates_in_primfuncs(func, pool_infos, constant_pool_infos): + """helper to assign poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + elif isinstance(stmt, tvm.tir.AllocateConst): + return tvm.tir.AllocateConst( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + data_or_idx=stmt.data, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: constant_pool_infos}, + ) + + return func.with_body(tvm.tir.stmt_functor.ir_transform(func.body, None, set_poolinfos)) + + +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos, constant_pool_infos=None): + """helper to assign poolinfos to allocate nodes in an IRModule""" + + @relax.expr_functor.mutator + class RelaxFuncAnnotate(PyExprMutator): + def visit_span(self, span: Span) -> Span: + pass + + def visit_call_(self, op: tvm.relax.Call) -> Expr: + call = op + if "relax.builtin.alloc_tensor" == str(call.op): + attrs = tvm.ir.attrs.make_node( + "relax.attrs.AllocTensorAttrs", + dtype=call.attrs["dtype"], + runtime_device_index=call.attrs["runtime_device_index"], + candidate_memory_pools=pool_infos, + ) + return tvm.relax.Call(call.op, call.args, attrs, call.type_args, call.span) + return super().visit_call_(op) + + relax_visitor = RelaxFuncAnnotate() + mod["run_model"] = relax_visitor.visit_expr(mod["run_model"]) + + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = _assign_poolinfos_to_allocates_in_primfuncs( + basefunc, pool_infos, constant_pool_infos + ) + else: + ret[global_var] = basefunc + return ret + + +def _assign_targets_to_relaxfuncs_irmodule(mod, target): + """helper to assign a target to the Relax and TIR functions in an IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, 
(tvm.relax.Function, tvm.tir.PrimFunc)): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + +device = cpu(0) + +# fmt: off +@tvm.script.ir_module +class LinearStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + Conv2dOutput_7[ff_3] = 0 + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) + for ax3_inner_7 in T.serial(0, 64): + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for 
ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") + + @R.function + def run_model(input: Tensor((16, 16), "uint8")) -> Tensor: + tsid_10 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_11 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_12 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + + lv0 = relax.call_tir("tvmgen_default_fused_cast_subtract", (input, tsid_10), (301056, 1), dtype="int32") + lv1 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", (lv0, tsid_11, tsid_12), (802816, 1), dtype="int32") + output = relax.call_tir("tvmgen_default_fused_nn_max_pool2d_cast", (lv1), (16, 16), dtype="int32") + return output +# fmt: on + + +def test_linear(): + target = Target("c") + relax_mod = LinearStructure + passes = [relax.transform.ToNonDataflow(), relax.transform.CallTIRRewrite()] + seq = tvm.transform.Sequential(passes) + relax_mod = seq(relax_mod) + + fast_memory_pool = WorkspacePoolInfo(pool_name="fast_memory", targets=[target]) + slow_memory_pool = WorkspacePoolInfo(pool_name="slow_memory", targets=[target]) + + relax_mod = _assign_targets_to_relaxfuncs_irmodule(relax_mod, target) + relax_mod = _assign_poolinfos_to_allocates_in_irmodule( + relax_mod, [fast_memory_pool, slow_memory_pool] + ) + + buffer_info_analysis = tvm.relax.analysis.extract_buffer_info(relax_mod["run_model"], relax_mod) + buffer_info_map_relax = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts) + + assert buffer_info_analysis.memory_pressure == 3526168 + + # check conflicts + _verify_conflicts( + "PaddedInput_7", + ["alloc", "Conv2dOutput_7", "alloc1", "tsid_11", "tsid_12"], + buffer_info_map_relax, + ) + _verify_conflicts("tsid_10", ["alloc"], buffer_info_map_relax) + _verify_conflicts("alloc2", ["tensor_2"], buffer_info_map_relax) + _verify_conflicts("alloc", ["tsid_10", "PaddedInput_7"], buffer_info_map_relax) + _verify_conflicts( + "tsid_11", ["PaddedInput_7", "alloc1", "tsid_12", "Conv2dOutput_7"], buffer_info_map_relax + ) + _verify_conflicts( + "alloc1", + ["tsid_11", "PaddedInput_7", "tsid_12", "Conv2dOutput_7", "tensor_2"], + buffer_info_map_relax, + ) + _verify_conflicts( + "tsid_12", ["alloc1", "PaddedInput_7", "tsid_11", "Conv2dOutput_7"], buffer_info_map_relax + ) + _verify_conflicts( + "Conv2dOutput_7", ["tsid_12", "alloc1", "PaddedInput_7", "tsid_11"], buffer_info_map_relax + ) + _verify_conflicts("tensor_2", ["alloc1", "alloc2"], buffer_info_map_relax) + + # check sizes + assert buffer_info_map_relax["alloc"].size_bytes == 1204224 + assert buffer_info_map_relax["alloc1"].size_bytes == 3211264 + assert 
buffer_info_map_relax["alloc2"].size_bytes == 1024 + assert buffer_info_map_relax["Conv2dOutput_7"].size_bytes == 256 + assert buffer_info_map_relax["PaddedInput_7"].size_bytes == 314646 + assert buffer_info_map_relax["tensor_2"].size_bytes == 200704 + + # check_pool_candidates + assert [ + pool_info.pool_name for pool_info in list(buffer_info_map_relax["alloc"].pool_candidates) + ] == ["fast_memory", "slow_memory"] + + +# fmt: off +@tvm.script.ir_module +class ParallelSerialMixedForLoops: + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [262144], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_8 in T.parallel(0, 3136): + dummy_allocate = T.allocate([1], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ff_4 in T.serial(0, 64): + Conv2dOutput_8[ff_4] = 0 + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) + for ax3_inner_8 in T.serial(0, 64): + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") + + @R.function + def run_model(input: Tensor((512, 512), "uint8")) -> Tensor: + tsid_10 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_11 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + + output = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", (input, tsid_10, tsid_11), (262144, 1), dtype="int32") + return output +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class AllSerialForLoops: + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", 
elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): + dummy_allocate = T.allocate([1], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ff_4 in T.serial(0, 64): + Conv2dOutput_8[ff_4] = 0 + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) + for ax3_inner_8 in T.serial(0, 64): + T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") + + + @R.function + def run_model(input: Tensor((512, 512), "uint8")) -> Tensor: + tsid_10 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_11 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + + output = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", (input, tsid_10, tsid_11), (262144, 1), dtype="int32") + return output +# fmt: on + + +def test_parallel_serial_mixed_for_loops(): + target = Target("c") + global_ws_pool = WorkspacePoolInfo( + pool_name="global_workspace", + targets=[target], + ) + passes = [relax.transform.ToNonDataflow(), relax.transform.CallTIRRewrite()] + seq = tvm.transform.Sequential(passes) + + all_serial_mod = AllSerialForLoops + all_serial_mod = seq(all_serial_mod) + all_serial_mod = _assign_targets_to_relaxfuncs_irmodule(all_serial_mod, target) + all_serial_mod = _assign_poolinfos_to_allocates_in_irmodule(all_serial_mod, [global_ws_pool]) + main_func = all_serial_mod["run_model"] + buffer_info_analysis = tvm.relax.analysis.extract_buffer_info(main_func, all_serial_mod) + assert buffer_info_analysis.memory_pressure == 1479426 + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts) + + # When all loops are serial all allocates are touched by USMP + assert len(buffer_info_map) == 6 + for name, _ in buffer_info_map.items(): + assert name in [ + "alloc", + "tsid_10", + "tsid_11", + "dummy_allocate", + "Conv2dOutput_8", + "PaddedInput_8", + ] + + parallel_serial_mixed_tir_mod = ParallelSerialMixedForLoops + parallel_serial_mixed_tir_mod = seq(parallel_serial_mixed_tir_mod) + parallel_serial_mixed_tir_mod = _assign_targets_to_relaxfuncs_irmodule( + parallel_serial_mixed_tir_mod, target + ) + parallel_serial_mixed_tir_mod = 
_assign_poolinfos_to_allocates_in_irmodule( + parallel_serial_mixed_tir_mod, [global_ws_pool] + ) + main_func = parallel_serial_mixed_tir_mod["run_model"] + buffer_info_analysis = tvm.relax.analysis.extract_buffer_info( + main_func, parallel_serial_mixed_tir_mod + ) + assert buffer_info_analysis.memory_pressure == 1479426 + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts) + + # USMP will not touch (yet) the allocates inside parallel for loops + assert len(buffer_info_map) == 5 + for name, _ in buffer_info_map.items(): + assert name in ["alloc", "tsid_10", "tsid_11", "Conv2dOutput_8", "PaddedInput_8"] + + +# fmt: off +@tvm.script.ir_module +class InceptionStructure: + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d(placeholder: T.handle, tensor: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d", "tir.noalias": True}) + placeholder_1 = T.match_buffer(placeholder, [602112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + tensor_1 = T.match_buffer(tensor, [249], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused in T.serial(0, 28): + for ax2 in T.serial(0, 28): + for ax3_outer_init, ax3_inner_init in T.grid(3, 64): + tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer_init*64)) + ax3_inner_init)] = T.uint8(0) + for rv0_rv1_fused, ax3_outer, ax3_inner in T.grid(9, 3, 64): + tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)] = T.max(tensor_1[((((ax0_ax1_fused*5376) + (ax2*192)) + (ax3_outer*64)) + ax3_inner)], T.if_then_else(((((ax0_ax1_fused*2) + T.floordiv(rv0_rv1_fused, 3)) < 56) and (((ax2*2) + T.floormod(rv0_rv1_fused, 3)) < 56)), placeholder_1[((((((ax0_ax1_fused*21504) + (T.floordiv(rv0_rv1_fused, 3)*10752)) + (ax2*384)) + (T.floormod(rv0_rv1_fused, 3)*192)) + (ax3_outer*64)) + ax3_inner)], T.uint8(0), dtype="uint8")) + + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0]) + + @T.prim_func + def tvmgen_default_fused_cast(placeholder_6: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_6, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_1 = T.match_buffer(T_cast, [249], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_2 in T.serial(0, 28): + for ax2_2, ax3_outer_1, ax3_inner_2 in T.grid(28, 12, 16): + T_cast_1[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)] = T.cast(placeholder_7[((((ax0_ax1_fused_2*5376) + (ax2_2*192)) + (ax3_outer_1*16)) + ax3_inner_2)], "int16") + + @T.prim_func + 
def tvmgen_default_fused_concatenate(placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, placeholder_11: T.handle, T_concat: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_concatenate", "tir.noalias": True}) + placeholder_12 = T.match_buffer(placeholder_8, [50176], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_concat_1 = T.match_buffer(T_concat, [313], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_13 = T.match_buffer(placeholder_9, [100352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_14 = T.match_buffer(placeholder_11, [25088], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_15 = T.match_buffer(placeholder_10, [25088], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_3 in T.serial(0, 28): + for ax2_3, ax3 in T.grid(28, 256): + T_concat_1[(((ax0_ax1_fused_3*7168) + (ax2_3*256)) + ax3)] = T.if_then_else((224 <= ax3), placeholder_14[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 224)], T.if_then_else((192 <= ax3), placeholder_15[((((ax0_ax1_fused_3*896) + (ax2_3*32)) + ax3) - 192)], T.if_then_else((64 <= ax3), placeholder_13[((((ax0_ax1_fused_3*3584) + (ax2_3*128)) + ax3) - 64)], placeholder_12[(((ax0_ax1_fused_3*1792) + (ax2_3*64)) + ax3)], dtype="uint8"), dtype="uint8"), dtype="uint8") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_20 = T.match_buffer(placeholder_17, [4096], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_21 = T.match_buffer(placeholder_18, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_3 = T.match_buffer(T_cast_2, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput = T.allocate([200704], "int16", "global") + for i0_i1_fused in T.serial(0, 56): + for i2, i3 in T.grid(56, 64): + PaddedInput[(((i0_i1_fused*3584) + (i2*64)) + i3)] = placeholder_19[(((i0_i1_fused*3584) + (i2*64)) + i3)] + for ax0_ax1_fused_ax2_fused in T.serial(0, 3136): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + Conv2dOutput[ff] = 0 + for rc in T.serial(0, 64): + Conv2dOutput[ff] = (Conv2dOutput[ff] + (T.cast(PaddedInput[((ax0_ax1_fused_ax2_fused*64) + rc)], "int32")*T.cast(placeholder_20[((rc*64) + ff)], "int32"))) + for ax3_inner_3 in T.serial(0, 64): + T_cast_3[((ax0_ax1_fused_ax2_fused*64) + ax3_inner_3)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput[ax3_inner_3] + placeholder_21[ax3_inner_3]), 1191576922, 31, -4, dtype="int32"), 255), 0), "uint8"), "int16") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_25 = T.match_buffer(placeholder_22, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + 
placeholder_26 = T.match_buffer(placeholder_23, [18432], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_27 = T.match_buffer(placeholder_24, [96], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_5 = T.match_buffer(T_cast_4, [153], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_1 = T.allocate([150528], "int16", "global") + for i0_i1_fused_1 in T.serial(0, 28): + for i2_1, i3_1 in T.grid(28, 192): + PaddedInput_1[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] = placeholder_25[(((i0_i1_fused_1*5376) + (i2_1*192)) + i3_1)] + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 784): + Conv2dOutput_1 = T.allocate([1], "int32", "global") + for ax3_1 in T.serial(0, 96): + Conv2dOutput_1[0] = 0 + for rc_1 in T.serial(0, 192): + Conv2dOutput_1[0] = (Conv2dOutput_1[0] + (T.cast(PaddedInput_1[((ax0_ax1_fused_ax2_fused_1*192) + rc_1)], "int32")*T.cast(placeholder_26[((rc_1*96) + ax3_1)], "int32"))) + T_cast_5[((ax0_ax1_fused_ax2_fused_1*96) + ax3_1)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_1[0] + placeholder_27[ax3_1]), 1201322342, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16") + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [177], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init)] = T.uint8(0) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)] = T.max(tensor_2[(((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)], T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), placeholder_29[(((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T_cast_7[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)] = T.cast(tensor_2[(((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)], "int16") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", "tir.noalias": True}) + placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = T.match_buffer(placeholder_31, [12288], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = T.match_buffer(placeholder_32, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = T.match_buffer(T_cast_8, [121], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_2 = T.allocate([150528], "int16", "global") + for i0_i1_fused_2 in T.serial(0, 28): + for i2_2, i3_2 in T.grid(28, 192): + 
PaddedInput_2[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] = placeholder_33[(((i0_i1_fused_2*5376) + (i2_2*192)) + i3_2)] + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 784): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + Conv2dOutput_2[ff_1] = 0 + for rc_2 in T.serial(0, 192): + Conv2dOutput_2[ff_1] = (Conv2dOutput_2[ff_1] + (T.cast(PaddedInput_2[((ax0_ax1_fused_ax2_fused_2*192) + rc_2)], "int32")*T.cast(placeholder_34[((rc_2*64) + ff_1)], "int32"))) + for ax3_inner_4 in T.serial(0, 64): + T_cast_9[((ax0_ax1_fused_ax2_fused_2*64) + ax3_inner_4)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_2[ax3_inner_4] + placeholder_35[ax3_inner_4]), 1663316467, 31, -7, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast_1(placeholder_36: T.handle, T_cast_10: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast_1", "tir.noalias": True}) + placeholder_37 = T.match_buffer(placeholder_36, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_11 = T.match_buffer(T_cast_10, [249], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_3 = T.allocate([150528], "uint8", "global") + for ax0_ax1_fused_6 in T.serial(0, 28): + for ax2_6 in T.serial(0, 28): + for ax3_outer_init_1, ax3_inner_init_1 in T.grid(3, 64): + tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_init_1*64)) + ax3_inner_init_1)] = T.uint8(0) + for rv0_rv1_fused_2, ax3_outer_2, ax3_inner_5 in T.grid(9, 3, 64): + tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)] = T.max(tensor_3[((((ax0_ax1_fused_6*5376) + (ax2_6*192)) + (ax3_outer_2*64)) + ax3_inner_5)], T.if_then_else(((((1 <= (T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6)) and ((T.floordiv(rv0_rv1_fused_2, 3) + ax0_ax1_fused_6) < 29)) and (1 <= (ax2_6 + T.floormod(rv0_rv1_fused_2, 3)))) and ((ax2_6 + T.floormod(rv0_rv1_fused_2, 3)) < 29)), placeholder_37[(((((((T.floordiv(rv0_rv1_fused_2, 3)*5376) + (ax0_ax1_fused_6*5376)) + (ax2_6*192)) + (T.floormod(rv0_rv1_fused_2, 3)*192)) + (ax3_outer_2*64)) + ax3_inner_5) - 5568)], T.uint8(0), dtype="uint8")) + for ax0_ax1_fused_7 in T.serial(0, 28): + for ax2_7, ax3_4 in T.grid(28, 192): + T_cast_11[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)] = T.cast(tensor_3[(((ax0_ax1_fused_7*5376) + (ax2_7*192)) + ax3_4)], "int16") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2(placeholder_38: T.handle, placeholder_39: T.handle, placeholder_40: T.handle, T_cast_12: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", "tir.noalias": True}) + placeholder_41 = T.match_buffer(placeholder_38, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_42 = T.match_buffer(placeholder_39, [6144], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_43 = T.match_buffer(placeholder_40, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_13 = T.match_buffer(T_cast_12, [89], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = T.allocate([150528], "int16", "global") + for i0_i1_fused_3 in T.serial(0, 28): + for i2_3, i3_3 in T.grid(28, 192): + PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + 
i3_3)] = placeholder_41[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)] + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 784): + Conv2dOutput_3 = T.allocate([1], "int32", "global") + for ax3_5 in T.serial(0, 32): + Conv2dOutput_3[0] = 0 + for rc_3 in T.serial(0, 192): + Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_42[((rc_3*32) + ax3_5)], "int32"))) + T_cast_13[((ax0_ax1_fused_ax2_fused_3*32) + ax3_5)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_43[ax3_5]), 1811141736, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1136333842, 31, 0, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_44: T.handle, placeholder_45: T.handle, placeholder_46: T.handle, T_cast_14: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_47 = T.match_buffer(placeholder_44, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_48 = T.match_buffer(placeholder_45, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_49 = T.match_buffer(placeholder_46, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_15 = T.match_buffer(T_cast_14, [73], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_4 = T.allocate([150528], "int16", "global") + for i0_i1_fused_4 in T.serial(0, 28): + for i2_4, i3_4 in T.grid(28, 192): + PaddedInput_4[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] = placeholder_47[(((i0_i1_fused_4*5376) + (i2_4*192)) + i3_4)] + for ax0_ax1_fused_ax2_fused_4 in T.serial(0, 784): + Conv2dOutput_4 = T.allocate([1], "int32", "global") + for ax3_6 in T.serial(0, 16): + Conv2dOutput_4[0] = 0 + for rc_4 in T.serial(0, 192): + Conv2dOutput_4[0] = (Conv2dOutput_4[0] + (T.cast(PaddedInput_4[((ax0_ax1_fused_ax2_fused_4*192) + rc_4)], "int32")*T.cast(placeholder_48[((rc_4*16) + ax3_6)], "int32"))) + T_cast_15[((ax0_ax1_fused_ax2_fused_4*16) + ax3_6)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_4[0] + placeholder_49[ax3_6]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1(placeholder_50: T.handle, placeholder_51: T.handle, placeholder_52: T.handle, T_cast_16: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", "tir.noalias": True}) + placeholder_53 = T.match_buffer(placeholder_50, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_54 = T.match_buffer(placeholder_51, [4608], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_55 = T.match_buffer(placeholder_52, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_17 = T.match_buffer(T_cast_16, [89], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_5 = T.allocate([14400], "int16", "global") + for i0_i1_fused_5 in T.serial(0, 30): + for i2_5, i3_5 in T.grid(30, 16): + PaddedInput_5[(((i0_i1_fused_5*480) + (i2_5*16)) + i3_5)] = T.if_then_else(((((1 <= i0_i1_fused_5) and 
(i0_i1_fused_5 < 29)) and (1 <= i2_5)) and (i2_5 < 29)), placeholder_53[((((i0_i1_fused_5*448) + (i2_5*16)) + i3_5) - 464)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_5 in T.serial(0, 784): + Conv2dOutput_5 = T.allocate([1], "int32", "global") + for ax3_7 in T.serial(0, 32): + Conv2dOutput_5[0] = 0 + for ry, rx, rc_5 in T.grid(3, 3, 16): + Conv2dOutput_5[0] = (Conv2dOutput_5[0] + (T.cast(PaddedInput_5[(((((T.floordiv(ax0_ax1_fused_ax2_fused_5, 28)*480) + (ry*480)) + (rx*16)) + (T.floormod(ax0_ax1_fused_ax2_fused_5, 28)*16)) + rc_5)], "int32")*T.cast(placeholder_54[((((ry*1536) + (rx*512)) + (rc_5*32)) + ax3_7)], "int32"))) + T_cast_17[((ax0_ax1_fused_ax2_fused_5*32) + ax3_7)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_5[0] + placeholder_55[ax3_7]), 1131968888, 31, -6, dtype="int32"), 255), 0), "uint8"), "int32"), 1900719667, 31, 0, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_(placeholder_56: T.handle, placeholder_57: T.handle, placeholder_58: T.handle, T_cast_18: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", "tir.noalias": True}) + placeholder_59 = T.match_buffer(placeholder_56, [75264], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_60 = T.match_buffer(placeholder_57, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_61 = T.match_buffer(placeholder_58, [128], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_19 = T.match_buffer(T_cast_18, [185], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_6 = T.allocate([86400], "int16", "global") + for i0_i1_fused_6 in T.serial(0, 30): + for i2_6, i3_6 in T.grid(30, 96): + PaddedInput_6[(((i0_i1_fused_6*2880) + (i2_6*96)) + i3_6)] = T.if_then_else(((((1 <= i0_i1_fused_6) and (i0_i1_fused_6 < 29)) and (1 <= i2_6)) and (i2_6 < 29)), placeholder_59[((((i0_i1_fused_6*2688) + (i2_6*96)) + i3_6) - 2784)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_6 in T.serial(0, 784): + Conv2dOutput_6 = T.allocate([64], "int32", "global") + for ax3_outer_3 in T.serial(0, 2): + for ff_2 in T.serial(0, 64): + Conv2dOutput_6[ff_2] = 0 + for ry_1, rx_1, rc_6 in T.grid(3, 3, 96): + Conv2dOutput_6[ff_2] = (Conv2dOutput_6[ff_2] + (T.cast(PaddedInput_6[(((((T.floordiv(ax0_ax1_fused_ax2_fused_6, 28)*2880) + (ry_1*2880)) + (rx_1*96)) + (T.floormod(ax0_ax1_fused_ax2_fused_6, 28)*96)) + rc_6)], "int32")*T.cast(placeholder_60[(((((ry_1*36864) + (rx_1*12288)) + (rc_6*128)) + (ax3_outer_3*64)) + ff_2)], "int32"))) + for ax3_inner_6 in T.serial(0, 64): + T_cast_19[(((ax0_ax1_fused_ax2_fused_6*128) + (ax3_outer_3*64)) + ax3_inner_6)] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_6[ax3_inner_6] + placeholder_61[((ax3_outer_3*64) + ax3_inner_6)]), 1374050734, 31, -7, dtype="int32"), 255), 0), "uint8"), "int32"), 1544713713, 31, 0, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": 
"tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "T.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [289], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + Conv2dOutput_7[ff_3] = 0 + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32"))) + for ax3_inner_7 in T.serial(0, 64): + T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8") + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placeholder_68: T.handle, placeholder_69: T.handle, placeholder_70: T.handle, T_cast_22: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", "tir.noalias": True}) + placeholder_71 = T.match_buffer(placeholder_68, [200704], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_72 = T.match_buffer(placeholder_69, [110592], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_73 = T.match_buffer(placeholder_70, [192], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_23 = T.match_buffer(T_cast_22, [305], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_8 = T.allocate([215296], "int16", "global") + for i0_i1_fused_8 in T.serial(0, 58): + for i2_8, i3_8 in T.grid(58, 64): + PaddedInput_8[(((i0_i1_fused_8*3712) + (i2_8*64)) + i3_8)] = T.if_then_else(((((1 <= i0_i1_fused_8) and (i0_i1_fused_8 < 57)) and (1 <= i2_8)) and (i2_8 < 57)), placeholder_71[((((i0_i1_fused_8*3584) + (i2_8*64)) + i3_8) - 3648)], T.int16(0), dtype="int16") + for ax0_ax1_fused_ax2_fused_8 in T.serial(0, 3136): + Conv2dOutput_8 = T.allocate([64], "int32", "global") + for ax3_outer_4 in T.serial(0, 3): + for ff_4 in T.serial(0, 64): + Conv2dOutput_8[ff_4] = 0 + for ry_3, rx_3, rc_8 in T.grid(3, 3, 64): + Conv2dOutput_8[ff_4] = (Conv2dOutput_8[ff_4] + (T.cast(PaddedInput_8[(((((T.floordiv(ax0_ax1_fused_ax2_fused_8, 56)*3712) + (ry_3*3712)) + (rx_3*64)) + (T.floormod(ax0_ax1_fused_ax2_fused_8, 56)*64)) + rc_8)], "int32")*T.cast(placeholder_72[(((((ry_3*36864) + (rx_3*12288)) + (rc_8*192)) + (ax3_outer_4*64)) + ff_4)], "int32"))) + for ax3_inner_8 in T.serial(0, 64): + 
T_cast_23[(((ax0_ax1_fused_ax2_fused_8*192) + (ax3_outer_4*64)) + ax3_inner_8)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_8[ax3_inner_8] + placeholder_73[((ax3_outer_4*64) + ax3_inner_8)]), 1139793473, 31, -6, dtype="int32"), 255), 0), "uint8") + + @R.function + def run_model(input: Tensor((16, 16), "uint8")) -> Tensor: + tsid_100 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_101 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + sid_9 = relax.call_tir("tvmgen_default_fused_cast_subtract", (input, tsid_100), (301056, 1), dtype="int32") + sid_8 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", (sid_9, tsid_100, tsid_101), (802816, 1), dtype="int32") + sid_7 = relax.call_tir("tvmgen_default_fused_nn_max_pool2d_cast", (sid_8), (401408, 1), dtype="int32") + sid_6 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", (sid_7, tsid_100, tsid_101), (401408, 1), dtype="int32") + sid_5 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1", (sid_6, tsid_100, tsid_101), (602112, 1), dtype="int32") + sid_4 = relax.call_tir("tvmgen_default_fused_nn_max_pool2d", (sid_5), (150528, 1), dtype="int32") + sid_3 = relax.call_tir("tvmgen_default_fused_cast", (sid_4), (301056, 1), dtype="int32") + sid_2 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_2", (sid_3, tsid_100, tsid_101), (50176, 1), dtype="int32") + sid_20 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", (sid_3, tsid_100, tsid_101), (150528, 1), dtype="int32") + sid_19 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320_", (sid_20, tsid_100, tsid_101), (100352, 1), dtype="int32") + sid_26 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", (sid_3, tsid_100, tsid_101), (25088, 1), dtype="int32") + sid_25 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__1", (sid_26, tsid_100, tsid_101), (25088, 1), dtype="int32") + sid_32 = relax.call_tir("tvmgen_default_fused_nn_max_pool2d_cast_1", (sid_4), (301056, 1), dtype="int32") + sid_31 = relax.call_tir("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_fixed_point_multiply_cli_4464294615199028320__2", (sid_32, tsid_100, tsid_101), (25088, 1), dtype="int32") + output = relax.call_tir("tvmgen_default_fused_concatenate", (sid_2, sid_19, sid_25, sid_31), (25088, 1), dtype="int32") + return output +# fmt: on + + +def test_inception_structure(): + target = Target("c") + global_ws_pool = WorkspacePoolInfo( + pool_name="global_workspace", + targets=[target], + ) + relax_mod = InceptionStructure + passes = [relax.transform.ToNonDataflow(), relax.transform.CallTIRRewrite()] + seq = tvm.transform.Sequential(passes) + relax_mod = seq(relax_mod) + relax_mod = _assign_targets_to_relaxfuncs_irmodule(relax_mod, target) + relax_mod = _assign_poolinfos_to_allocates_in_irmodule(relax_mod, [global_ws_pool]) + main_func = relax_mod["run_model"] + buffer_info_analysis = tvm.relax.analysis.extract_buffer_info(main_func, relax_mod) + assert buffer_info_analysis.memory_pressure == 3526168 + buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts) + + # check conflicts + _verify_conflicts( + "alloc9", + [ + 
"PaddedInput_6", + "alloc7", + "tsid_100", + "tsid_101", + "alloc5", + "alloc6", + "Conv2dOutput_6", + "PaddedInput_4", + "alloc10", + "Conv2dOutput_4", + "PaddedInput_5", + "Conv2dOutput_5", + "alloc11", + "tensor_3", + "alloc12", + "PaddedInput_3", + "Conv2dOutput_3", + "alloc13", + "alloc14", + ], + buffer_info_map, + ) + _verify_conflicts( + "tensor_3", + [ + "alloc11", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc5", + "alloc12", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc1", + [ + "Conv2dOutput_7", + "PaddedInput_7", + "tsid_100", + "tsid_101", + "tensor_2", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_4", + [ + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc5", + "alloc6", + "alloc10", + "Conv2dOutput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_8", + [ + "alloc3", + "tsid_101", + "tsid_100", + "Conv2dOutput_8", + "alloc4", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc8", + [ + "PaddedInput_1", + "alloc7", + "tsid_100", + "tsid_101", + "alloc5", + "alloc6", + "Conv2dOutput_1", + "PaddedInput_6", + ], + buffer_info_map, + ) + _verify_conflicts( + "tsid_101", + [ + "alloc1", + "Conv2dOutput_7", + "PaddedInput_7", + "tsid_100", + "tensor_2", + "alloc2", + "PaddedInput", + "Conv2dOutput", + "alloc3", + "PaddedInput_8", + "Conv2dOutput_8", + "alloc4", + "alloc5", + "alloc6", + "PaddedInput_2", + "Conv2dOutput_2", + "alloc7", + "PaddedInput_1", + "alloc8", + "Conv2dOutput_1", + "PaddedInput_6", + "alloc9", + "Conv2dOutput_6", + "PaddedInput_4", + "alloc10", + "Conv2dOutput_4", + "PaddedInput_5", + "Conv2dOutput_5", + "alloc11", + "tensor_3", + "alloc12", + "PaddedInput_3", + "Conv2dOutput_3", + "alloc13", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput", + [ + "PaddedInput", + "tsid_101", + "tsid_100", + "alloc3", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_3", + [ + "PaddedInput_3", + "alloc11", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc13", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_7", + [ + "alloc", + "tsid_100", + "Conv2dOutput_7", + "alloc1", + "tsid_101", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc6", + [ + "alloc5", + "tsid_101", + "tsid_100", + "PaddedInput_2", + "Conv2dOutput_2", + "alloc7", + "PaddedInput_1", + "alloc8", + "Conv2dOutput_1", + "PaddedInput_6", + "alloc9", + "Conv2dOutput_6", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_5", + [ + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "PaddedInput_5", + "alloc5", + "alloc11", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc7", + [ + "Conv2dOutput_2", + "tsid_100", + "tsid_101", + "alloc5", + "alloc6", + "PaddedInput_2", + "PaddedInput_1", + "alloc8", + "Conv2dOutput_1", + "PaddedInput_6", + "alloc9", + "Conv2dOutput_6", + "PaddedInput_4", + "alloc10", + "Conv2dOutput_4", + "PaddedInput_5", + "Conv2dOutput_5", + "alloc11", + "tensor_3", + "alloc12", + "PaddedInput_3", + "Conv2dOutput_3", + "alloc13", + "alloc14", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_2", + [ + "alloc6", + "alloc5", + "tsid_101", + "tsid_100", + "Conv2dOutput_2", + "alloc7", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_1", + [ + "alloc7", + "tsid_100", + "tsid_101", + "alloc5", + "alloc6", + "alloc8", + "Conv2dOutput_1", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc11", + [ + "Conv2dOutput_5", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "PaddedInput_5", + 
"alloc5", + "tensor_3", + "alloc12", + "PaddedInput_3", + "Conv2dOutput_3", + "alloc13", + "alloc14", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc", + [ + "tsid_100", + "PaddedInput_7", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc4", + [ + "Conv2dOutput_8", + "PaddedInput_8", + "tsid_101", + "tsid_100", + "alloc5", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_6", + [ + "alloc7", + "tsid_100", + "alloc8", + "tsid_101", + "alloc5", + "alloc6", + "alloc9", + "Conv2dOutput_6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc2", + [ + "tensor_2", + "tsid_101", + "tsid_100", + "PaddedInput", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc12", + [ + "tensor_3", + "alloc11", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "PaddedInput_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_8", + [ + "PaddedInput_8", + "tsid_101", + "tsid_100", + "alloc4", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc10", + [ + "PaddedInput_4", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc5", + "Conv2dOutput_4", + "PaddedInput_5", + ], + buffer_info_map, + ) + _verify_conflicts( + "tensor_2", + [ + "alloc1", + "tsid_101", + "tsid_100", + "alloc2", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "PaddedInput_2", + "alloc6", + "alloc5", + "tsid_101", + "tsid_100", + "alloc7", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_4", + [ + "alloc10", + "PaddedInput_4", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc5", + ], + buffer_info_map, + ) + _verify_conflicts( + "tsid_100", + [ + "alloc", + "PaddedInput_7", + "Conv2dOutput_7", + "alloc1", + "tsid_101", + "tensor_2", + "alloc2", + "PaddedInput", + "Conv2dOutput", + "alloc3", + "PaddedInput_8", + "Conv2dOutput_8", + "alloc4", + "alloc5", + "alloc6", + "PaddedInput_2", + "Conv2dOutput_2", + "alloc7", + "PaddedInput_1", + "alloc8", + "Conv2dOutput_1", + "PaddedInput_6", + "alloc9", + "Conv2dOutput_6", + "PaddedInput_4", + "alloc10", + "Conv2dOutput_4", + "PaddedInput_5", + "Conv2dOutput_5", + "alloc11", + "tensor_3", + "alloc12", + "PaddedInput_3", + "Conv2dOutput_3", + "alloc13", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc3", + [ + "Conv2dOutput", + "PaddedInput", + "tsid_101", + "tsid_100", + "PaddedInput_8", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc5", + [ + "alloc4", + "tsid_101", + "tsid_100", + "alloc6", + "PaddedInput_2", + "Conv2dOutput_2", + "alloc7", + "PaddedInput_1", + "alloc8", + "Conv2dOutput_1", + "PaddedInput_6", + "alloc9", + "Conv2dOutput_6", + "PaddedInput_4", + "alloc10", + "Conv2dOutput_4", + "PaddedInput_5", + "Conv2dOutput_5", + "alloc11", + "tensor_3", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc14", + [ + "alloc13", + "alloc11", + "alloc7", + "alloc9", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput", + [ + "alloc2", + "tsid_101", + "tsid_100", + "Conv2dOutput", + "alloc3", + ], + buffer_info_map, + ) + _verify_conflicts( + "PaddedInput_3", + [ + "alloc12", + "alloc11", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "Conv2dOutput_3", + "alloc13", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_1", + [ + "PaddedInput_1", + "alloc7", + "tsid_100", + "alloc8", + "tsid_101", + "alloc5", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_6", + [ + "PaddedInput_6", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc5", + "alloc6", + ], + buffer_info_map, + ) + 
_verify_conflicts( + "PaddedInput_5", + [ + "alloc10", + "alloc7", + "tsid_100", + "alloc9", + "tsid_101", + "alloc5", + "Conv2dOutput_5", + "alloc11", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc13", + [ + "PaddedInput_3", + "alloc11", + "alloc7", + "tsid_100", + "Conv2dOutput_3", + "alloc9", + "tsid_101", + "alloc14", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_7", + [ + "PaddedInput_7", + "tsid_100", + "alloc1", + "tsid_101", + ], + buffer_info_map, + ) + + # check sizes + assert buffer_info_map["Conv2dOutput"].size_bytes == 256 + assert buffer_info_map["PaddedInput_7"].size_bytes == 314646 + assert buffer_info_map["PaddedInput_6"].size_bytes == 172800 + assert buffer_info_map["alloc9"].size_bytes == 401408 + assert buffer_info_map["PaddedInput_2"].size_bytes == 301056 + assert buffer_info_map["alloc11"].size_bytes == 100352 + assert buffer_info_map["alloc6"].size_bytes == 1204224 + assert buffer_info_map["Conv2dOutput_4"].size_bytes == 4 + assert buffer_info_map["alloc2"].size_bytes == 1605632 + assert buffer_info_map["alloc13"].size_bytes == 100352 + assert buffer_info_map["alloc"].size_bytes == 1204224 + assert buffer_info_map["PaddedInput_8"].size_bytes == 430592 + assert buffer_info_map["tsid_100"].size_bytes == 1 + assert buffer_info_map["Conv2dOutput_2"].size_bytes == 256 + assert buffer_info_map["PaddedInput_3"].size_bytes == 301056 + assert buffer_info_map["tensor_3"].size_bytes == 150528 + assert buffer_info_map["Conv2dOutput_5"].size_bytes == 4 + assert buffer_info_map["Conv2dOutput_7"].size_bytes == 256 + assert buffer_info_map["PaddedInput_1"].size_bytes == 301056 + assert buffer_info_map["Conv2dOutput_6"].size_bytes == 256 + assert buffer_info_map["PaddedInput"].size_bytes == 401408 + assert buffer_info_map["alloc12"].size_bytes == 1204224 + assert buffer_info_map["alloc5"].size_bytes == 602112 + assert buffer_info_map["tensor_2"].size_bytes == 200704 + assert buffer_info_map["alloc10"].size_bytes == 100352 + assert buffer_info_map["alloc7"].size_bytes == 200704 + assert buffer_info_map["alloc3"].size_bytes == 1605632 + assert buffer_info_map["Conv2dOutput_8"].size_bytes == 256 + assert buffer_info_map["Conv2dOutput_3"].size_bytes == 4 + assert buffer_info_map["alloc8"].size_bytes == 602112 + assert buffer_info_map["tsid_101"].size_bytes == 1 + assert buffer_info_map["Conv2dOutput_1"].size_bytes == 4 + assert buffer_info_map["alloc4"].size_bytes == 2408448 + assert buffer_info_map["alloc1"].size_bytes == 3211264 + assert buffer_info_map["alloc14"].size_bytes == 100352 + assert buffer_info_map["PaddedInput_4"].size_bytes == 301056 + assert buffer_info_map["PaddedInput_5"].size_bytes == 28800 + + +# fmt: off +@tvm.script.ir_module +class MultipleCallsToSamePrimFuncModule: + @T.prim_func + def tvmgen_default_fused_layout_transform_1(placeholder: T.handle, T_layout_trans: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform_1", "tir.noalias": True}) + placeholder_1 = T.match_buffer(placeholder, [864], dtype="float32") + T_layout_trans_1 = T.match_buffer(T_layout_trans, [41], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused, ax3, ax4_inner in T.grid(24, 12, 3): + T_layout_trans_1[ax0_ax1_fused_ax2_fused * 36 + ax3 * 3 + ax4_inner] = placeholder_1[ax4_inner * 288 + ax0_ax1_fused_ax2_fused * 12 + ax3] + + @T.prim_func + def tvmgen_default_fused_nn_contrib_conv2d_NCHWc(placeholder_2: T.handle, placeholder_3: T.handle, conv2d_NCHWc: T.handle) -> 
None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_conv2d_NCHWc", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [864], dtype="float32") + placeholder_5 = T.match_buffer(placeholder_3, [81], dtype="float32") + conv2d_NCHWc_1 = T.match_buffer(conv2d_NCHWc, [41], dtype="float32") + # body + data_pad = T.allocate([1092], "float32", "global") + for i0_i1_fused_i2_fused, i3, i4 in T.grid(26, 14, 3): + data_pad[i0_i1_fused_i2_fused * 42 + i3 * 3 + i4] = T.if_then_else(1 <= i0_i1_fused_i2_fused and i0_i1_fused_i2_fused < 25 and 1 <= i3 and i3 < 13, placeholder_4[i0_i1_fused_i2_fused * 36 + i3 * 3 + i4 - 39], T.float32(0), dtype="float32") + for n_oc_chunk_fused_oh_fused in T.serial(0, 24): + conv2d_NCHWc_global = T.allocate([36], "float32", "global") + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 3] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 6] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 9] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 12] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 15] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 18] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 21] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 24] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 27] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 30] = T.float32(0) + for oc_block_c_init in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c_init + 33] = T.float32(0) + for kh, kw, ic_inner in T.grid(3, 3, 3): + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c] = conv2d_NCHWc_global[oc_block_c] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 3] = conv2d_NCHWc_global[oc_block_c + 3] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 3] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 6] = conv2d_NCHWc_global[oc_block_c + 6] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 6] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 9] = conv2d_NCHWc_global[oc_block_c + 9] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 9] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 12] = conv2d_NCHWc_global[oc_block_c + 12] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 12] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 15] = conv2d_NCHWc_global[oc_block_c + 15] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 15] * placeholder_5[kh * 27 + kw 
* 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 18] = conv2d_NCHWc_global[oc_block_c + 18] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 18] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 21] = conv2d_NCHWc_global[oc_block_c + 21] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 21] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 24] = conv2d_NCHWc_global[oc_block_c + 24] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 24] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 27] = conv2d_NCHWc_global[oc_block_c + 27] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 27] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 30] = conv2d_NCHWc_global[oc_block_c + 30] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 30] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for oc_block_c in T.serial(0, 3): + conv2d_NCHWc_global[oc_block_c + 33] = conv2d_NCHWc_global[oc_block_c + 33] + data_pad[kh * 42 + n_oc_chunk_fused_oh_fused * 42 + kw * 3 + ic_inner + 33] * placeholder_5[kh * 27 + kw * 9 + ic_inner * 3 + oc_block_c] + for ow_inner, oc_block in T.grid(12, 3): + conv2d_NCHWc_1[n_oc_chunk_fused_oh_fused * 36 + ow_inner * 3 + oc_block] = conv2d_NCHWc_global[ow_inner * 3 + oc_block] + + @T.prim_func + def tvmgen_default_fused_nn_softmax_add_add_multiply_add(placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add_add_multiply_add", "tir.noalias": True}) + placeholder_11 = T.match_buffer(placeholder_6, [864], dtype="float32") + placeholder_12 = T.match_buffer(placeholder_7, [864], dtype="float32") + placeholder_13 = T.match_buffer(placeholder_8, [3], dtype="float32") + placeholder_14 = T.match_buffer(placeholder_9, [3], dtype="float32") + placeholder_15 = T.match_buffer(placeholder_10, [3], dtype="float32") + T_add_1 = T.match_buffer(T_add, [864], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused in T.serial(0, 72): + T_softmax_norm = T.allocate([12], "float32", "global") + with T.allocate([1], "float32", "global") as T_softmax_maxelem: + T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) + for k in T.serial(0, 12): + T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_11[ax0_ax1_fused_ax2_fused * 12 + k]) + T_softmax_exp = T.allocate([12], "float32", "global") + for i3 in T.serial(0, 12): + T_softmax_exp[i3] = T.exp(placeholder_11[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") + T_softmax_expsum = T.allocate([1], "float32", "global") + T_softmax_expsum[0] = T.float32(0) + for k in T.serial(0, 12): + T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] + for i3 in T.serial(0, 12): + T_softmax_norm[i3] = T_softmax_exp[i3] / T_softmax_expsum[0] + for ax3 in T.serial(0, 12): + T_add_1[ax0_ax1_fused_ax2_fused * 12 + ax3] = (placeholder_12[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] + 
placeholder_13[T.floordiv(ax0_ax1_fused_ax2_fused, 24)]) * placeholder_14[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] + placeholder_15[T.floordiv(ax0_ax1_fused_ax2_fused, 24)] + + @T.prim_func + def tvmgen_default_fused_nn_contrib_dense_pack_nn_relu(placeholder_16: T.handle, placeholder_17: T.handle, T_relu: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", "tir.noalias": True}) + placeholder_18 = T.match_buffer(placeholder_16, [864], dtype="float32") + placeholder_19 = T.match_buffer(placeholder_17, [144], dtype="float32") + T_relu_1 = T.match_buffer(T_relu, [864], dtype="float32") + # body + for ax1_outer_ax0_outer_fused in T.serial(0, 18): + compute = T.allocate([48], "float32", "global") + with T.allocate([48], "float32", "global") as compute_global: + for x_c_init in T.serial(0, 6): + compute_global[x_c_init] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 6] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 12] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 18] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 24] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 30] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 36] = T.float32(0) + for x_c_init in T.serial(0, 6): + compute_global[x_c_init + 42] = T.float32(0) + for k_outer in T.serial(0, 12): + for x_c in T.serial(0, 6): + compute_global[x_c] = compute_global[x_c] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 6] = compute_global[x_c + 6] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 12] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 12] = compute_global[x_c + 12] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 24] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 18] = compute_global[x_c + 18] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 36] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 24] = compute_global[x_c + 24] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 48] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 30] = compute_global[x_c + 30] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 60] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 36] = compute_global[x_c + 36] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 72] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_c in T.serial(0, 6): + compute_global[x_c + 42] = compute_global[x_c + 42] + placeholder_18[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + k_outer + 84] * placeholder_19[T.floordiv(ax1_outer_ax0_outer_fused, 9) * 72 + k_outer * 6 + x_c] + for x_inner_inner in T.serial(0, 6): + 
compute[x_inner_inner] = compute_global[x_inner_inner] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 6] = compute_global[x_inner_inner + 6] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 12] = compute_global[x_inner_inner + 12] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 18] = compute_global[x_inner_inner + 18] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 24] = compute_global[x_inner_inner + 24] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 30] = compute_global[x_inner_inner + 30] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 36] = compute_global[x_inner_inner + 36] + for x_inner_inner in T.serial(0, 6): + compute[x_inner_inner + 42] = compute_global[x_inner_inner + 42] + for ax0_inner_inner, ax1_inner_inner in T.grid(8, 6): + T_relu_1[T.floormod(ax1_outer_ax0_outer_fused, 9) * 96 + ax0_inner_inner * 12 + T.floordiv(ax1_outer_ax0_outer_fused, 9) * 6 + ax1_inner_inner] = T.max(compute[ax0_inner_inner * 6 + ax1_inner_inner], T.float32(0)) + + @T.prim_func + def tvmgen_default_fused_reshape_1(placeholder_20: T.handle, T_reshape: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape_1", "tir.noalias": True}) + placeholder_21 = T.match_buffer(placeholder_20, [864], dtype="float32") + T_reshape_1 = T.match_buffer(T_reshape, [864], dtype="float32") + # body + for ax0, ax1_inner in T.grid(72, 12): + T_reshape_1[ax0 * 12 + ax1_inner] = placeholder_21[ax0 * 12 + ax1_inner] + + @T.prim_func + def tvmgen_default_fused_layout_transform(placeholder_22: T.handle, T_layout_trans_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_layout_transform", "tir.noalias": True}) + placeholder_23 = T.match_buffer(placeholder_22, [864], dtype="float32") + T_layout_trans_3 = T.match_buffer(T_layout_trans_2, [864], dtype="float32") + # body + for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): + T_layout_trans_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_23[ax2 * 36 + ax3_inner * 3 + ax0_ax1_fused] + + @T.prim_func + def tvmgen_default_fused_reshape(placeholder_24: T.handle, T_reshape_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_reshape", "tir.noalias": True}) + placeholder_25 = T.match_buffer(placeholder_24, [864], dtype="float32") + T_reshape_3 = T.match_buffer(T_reshape_2, [864], dtype="float32") + # body + for ax0_ax1_fused, ax2, ax3_inner in T.grid(3, 24, 12): + T_reshape_3[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] = placeholder_25[ax0_ax1_fused * 288 + ax2 * 12 + ax3_inner] + + @T.prim_func + def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27: T.handle, T_add_2: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "tvmgen_default_fused_nn_softmax_add", "tir.noalias": True}) + placeholder_28 = T.match_buffer(placeholder_26, [864], dtype="float32") + placeholder_29 = T.match_buffer(placeholder_27, [864], dtype="float32") + T_add_3 = T.match_buffer(T_add_2, [864], dtype="float32") + # body + for ax0_ax1_fused_ax2_fused in T.serial(0, 72): + T_softmax_norm = T.allocate([12], "float32", "global") + with T.allocate([1], "float32", "global") as T_softmax_maxelem: + T_softmax_maxelem[0] = T.float32(-3.4028234663852886e+38) + for k in T.serial(0, 12): + 
T_softmax_maxelem[0] = T.max(T_softmax_maxelem[0], placeholder_28[ax0_ax1_fused_ax2_fused * 12 + k]) + T_softmax_exp = T.allocate([12], "float32", "global") + for i3 in T.serial(0, 12): + T_softmax_exp[i3] = T.exp(placeholder_28[ax0_ax1_fused_ax2_fused * 12 + i3] - T_softmax_maxelem[0], dtype="float32") + T_softmax_expsum = T.allocate([1], "float32", "global") + T_softmax_expsum[0] = T.float32(0) + for k in T.serial(0, 12): + T_softmax_expsum[0] = T_softmax_expsum[0] + T_softmax_exp[k] + for i3 in T.serial(0, 12): + T_softmax_norm[i3] = T_softmax_exp[i3] / T_softmax_expsum[0] + for ax3 in T.serial(0, 12): + T_add_3[ax0_ax1_fused_ax2_fused * 12 + ax3] = placeholder_29[ax0_ax1_fused_ax2_fused * 12 + ax3] + T_softmax_norm[ax3] + + @R.function + def run_model(input: Tensor((16, 16), "uint8")) -> Tensor: + tsid_100 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_101 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + tsid_102 = relax.builtin.alloc_tensor((1, 1), runtime_device_index=0, dtype="int8") + sid_8 = relax.builtin.alloc_tensor((3456, 1), runtime_device_index=0, dtype="int8") + + sid_23 = relax.call_tir("tvmgen_default_fused_layout_transform_1", (input), (2, 1), dtype="int32") + sid_7 = relax.call_tir("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", (sid_8, tsid_100), (3456, 1), dtype="int32") + sid_6 = relax.call_tir("tvmgen_default_fused_layout_transform", (sid_7), (3456, 1), dtype="int32") + sid_12 = relax.call_tir("tvmgen_default_fused_reshape_1", (input), (3456, 1), dtype="int32") + sid_11 = relax.call_tir("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", (sid_12, tsid_100), (3456, 1), dtype="int32") + sid_10 = relax.call_tir("tvmgen_default_fused_reshape", (sid_11), (3456, 1), dtype="int32") + sid_5 = relax.call_tir("tvmgen_default_fused_nn_softmax_add_add_multiply_add", (sid_6, sid_10, tsid_100, tsid_101, tsid_102), (3456, 1), dtype="int32") + sid_4 = relax.call_tir("tvmgen_default_fused_layout_transform_1", (sid_5), (3456, 1), dtype="int32") + sid_3 = relax.call_tir("tvmgen_default_fused_nn_contrib_conv2d_NCHWc", (sid_4, tsid_100), (3456, 1), dtype="int32") + sid_2 = relax.call_tir("tvmgen_default_fused_layout_transform", (sid_3), (3456, 1), dtype="int32") + sid_20 = relax.call_tir("tvmgen_default_fused_reshape_1", (sid_5), (3456, 1), dtype="int32") + sid_19 = relax.call_tir("tvmgen_default_fused_nn_contrib_dense_pack_nn_relu", (sid_20, tsid_100), (3456, 1), dtype="int32") + sid_18 = relax.call_tir("tvmgen_default_fused_reshape", (sid_19), (3456, 1), dtype="int32") + output = relax.call_tir("tvmgen_default_fused_nn_softmax_add", (sid_2, sid_18), (3456, 1), dtype="int32") + return output +# fmt: on + + +def test_multiple_calls_to_same_primfunc(): + target = Target("c") + global_ws_pool = WorkspacePoolInfo( + pool_name="global_workspace", + targets=[target], + ) + global_const_pool = ConstantPoolInfo( + pool_name="global_constants", + targets=[target], + ) + + relax_mod = MultipleCallsToSamePrimFuncModule + passes = [relax.transform.ToNonDataflow(), relax.transform.CallTIRRewrite()] + seq = tvm.transform.Sequential(passes) + relax_mod = seq(relax_mod) + relax_mod = _assign_targets_to_relaxfuncs_irmodule(relax_mod, target) + relax_mod = _assign_poolinfos_to_allocates_in_irmodule( + relax_mod, [global_ws_pool], [global_const_pool] + ) + main_func = relax_mod["run_model"] + buffer_info_analysis = tvm.relax.analysis.extract_buffer_info(main_func, relax_mod) + assert buffer_info_analysis.memory_pressure == 41857 + 
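+    # memory_pressure is the analysis' estimate of the peak memory the
+    # function requires (41857 for this module); the _verify_conflicts calls
+    # below then pin down the liveness-conflict graph by asserting, for each
+    # buffer, the set of buffers whose live ranges overlap with it.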
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts) + + # check conflicts + _verify_conflicts( + "T_softmax_norm", + [ + "alloc5", + "tsid_100", + "alloc2", + "T_softmax_maxelem", + "T_softmax_expsum", + "T_softmax_exp", + "tsid_101", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc10", + [ + "alloc9", + "alloc6", + "tsid_100", + "compute", + "compute_global", + "alloc11", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc11", + [ + "compute_global", + "compute", + "alloc9", + "alloc10", + "tsid_100", + "alloc12", + ], + buffer_info_map, + ) + _verify_conflicts( + "compute", + [ + "alloc4", + "compute_global", + "alloc3", + "alloc2", + "tsid_100", + "alloc9", + "alloc10", + "tsid_100", + "compute_global", + "alloc11", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc1", + [ + "tsid_100", + "data_pad", + "conv2d_NCHWc_global", + "alloc2", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc2", + [ + "alloc1", + "tsid_100", + "alloc3", + "alloc4", + "compute_global", + "compute", + "alloc5", + "T_softmax_norm", + "T_softmax_maxelem", + "T_softmax_expsum", + "T_softmax_exp", + "tsid_101", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc8", + [ + "conv2d_NCHWc_global", + "data_pad", + "alloc6", + "tsid_100", + "alloc9", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc12", + [ + "alloc9", + "alloc11", + "T_softmax_expsum2", + "alloc13", + "T_softmax_norm2", + "T_softmax_exp2", + "T_softmax_maxelem2", + ], + buffer_info_map, + ) + _verify_conflicts( + "tsid_100", + [ + "conv2d_NCHWc_global", + "data_pad", + "alloc1", + "alloc2", + "alloc3", + "alloc4", + "compute_global", + "compute", + "alloc5", + "T_softmax_norm", + "T_softmax_maxelem", + "T_softmax_expsum", + "T_softmax_exp", + "tsid_101", + "tsid_102", + "alloc6", + "alloc7", + "data_pad", + "conv2d_NCHWc_global", + "alloc8", + "alloc9", + "alloc10", + "compute", + "compute_global", + "alloc11", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_maxelem2", + [ + "T_softmax_exp2", + "T_softmax_norm2", + "alloc13", + "T_softmax_expsum2", + "alloc12", + "alloc9", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_maxelem", + [ + "T_softmax_norm", + "alloc5", + "tsid_100", + "alloc2", + "T_softmax_expsum", + "T_softmax_exp", + "tsid_101", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_norm2", + [ + "alloc13", + "T_softmax_expsum2", + "alloc12", + "alloc9", + "T_softmax_exp2", + "T_softmax_maxelem2", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_exp2", + [ + "T_softmax_norm2", + "alloc13", + "T_softmax_expsum2", + "alloc12", + "alloc9", + "T_softmax_maxelem2", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_expsum2", + [ + "alloc12", + "alloc9", + "alloc13", + "T_softmax_norm2", + "T_softmax_exp2", + "T_softmax_maxelem2", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc5", + [ + "tsid_100", + "alloc2", + "alloc4", + "T_softmax_norm", + "T_softmax_maxelem", + "T_softmax_expsum", + "T_softmax_exp", + "tsid_101", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc", + [], + buffer_info_map, + ) + _verify_conflicts( + "sid_8", + [ + "data_pad", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc4", + [ + "alloc3", + "alloc2", + "tsid_100", + "compute_global", + "compute", + "alloc5", + ], + buffer_info_map, + ) + _verify_conflicts( + "compute_global", + [ + "alloc4", + 
"alloc3", + "alloc2", + "tsid_100", + "compute", + "compute", + "alloc9", + "alloc10", + "tsid_100", + "alloc11", + ], + buffer_info_map, + ) + _verify_conflicts( + "data_pad", + [ + "sid_8", + "conv2d_NCHWc_global", + "tsid_100", + "alloc1", + "alloc7", + "alloc6", + "tsid_100", + "conv2d_NCHWc_global", + "alloc8", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc3", + [ + "alloc2", + "tsid_100", + "alloc4", + "compute_global", + "compute", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_expsum", + [ + "T_softmax_maxelem", + "T_softmax_norm", + "alloc5", + "tsid_100", + "alloc2", + "T_softmax_exp", + "tsid_101", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "T_softmax_exp", + [ + "T_softmax_maxelem", + "T_softmax_norm", + "T_softmax_expsum", + "alloc5", + "tsid_100", + "alloc2", + "tsid_101", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc7", + [ + "alloc6", + "tsid_100", + "data_pad", + ], + buffer_info_map, + ) + _verify_conflicts( + "tsid_101", + [ + "T_softmax_exp", + "T_softmax_maxelem", + "T_softmax_norm", + "T_softmax_expsum", + "alloc5", + "tsid_100", + "alloc2", + "tsid_102", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc9", + [ + "alloc6", + "alloc8", + "tsid_100", + "alloc10", + "compute", + "compute_global", + "alloc11", + "alloc12", + "T_softmax_expsum2", + "alloc13", + "T_softmax_norm2", + "T_softmax_exp2", + "T_softmax_maxelem2", + ], + buffer_info_map, + ) + _verify_conflicts( + "conv2d_NCHWc_global", + [ + "data_pad", + "tsid_100", + "alloc1", + "data_pad", + "alloc6", + "tsid_100", + "alloc8", + ], + buffer_info_map, + ) + _verify_conflicts( + "tsid_102", + [ + "tsid_101", + "T_softmax_exp", + "T_softmax_maxelem", + "T_softmax_norm", + "T_softmax_expsum", + "alloc5", + "tsid_100", + "alloc2", + "alloc6", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc13", + [ + "T_softmax_expsum2", + "alloc12", + "alloc9", + "T_softmax_norm2", + "T_softmax_exp2", + "T_softmax_maxelem2", + ], + buffer_info_map, + ) + _verify_conflicts( + "alloc6", + [ + "tsid_102", + "tsid_101", + "T_softmax_exp", + "T_softmax_maxelem", + "T_softmax_norm", + "T_softmax_expsum", + "alloc5", + "tsid_100", + "alloc2", + "alloc7", + "data_pad", + "conv2d_NCHWc_global", + "alloc8", + "alloc9", + "alloc10", + ], + buffer_info_map, + ) + + +if __name__ == "__main__": + pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py new file mode 100644 index 0000000000..d071ffae63 --- /dev/null +++ b/tests/python/relax/test_relay_translator.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tempfile + +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import meta_schedule as ms +from tvm import relax, relay +from tvm.ir.base import assert_structural_equal +from tvm.relax.testing import relay_translator +from tvm.relay import testing +from tvm.runtime import vm +from tvm.script import tir as T +from tvm.target import Target + + +def get_resnet(batch_size, dtype, layout, image_shape): + relay_mod, params = testing.resnet.get_workload( + num_layers=18, + batch_size=batch_size, + dtype=dtype, + layout=layout, + image_shape=image_shape, + ) + + return relay_mod, params + + +def relay_build_and_run(mod, target, dev, params, data): + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relay_integration.tune_relay( + mod=mod, + params=params, + target=target, + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + task_scheduler="round-robin", + work_dir=work_dir, + ) + ex = ms.relay_integration.compile_relay( + db, + mod=mod, + target=target, + params=params, + ) + rt_mod = tvm.contrib.graph_executor.GraphModule(ex["default"](dev)) + rt_mod.set_input("data", data) + rt_mod.run() + out = rt_mod.get_output(0).numpy() + return ex, rt_mod, out + + +def relax_build_and_run(mod, target, dev, params, data): + mod = relax.transform.BindParams("main", params)(mod) + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relax_integration.tune_relax( + mod=mod, + target=target, + task_scheduler="round-robin", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + work_dir=work_dir, + ) + ex = ms.relax_integration.compile_relax( + db, + mod=mod, + target=target, + params=params, + ) + vm = relax.VirtualMachine(ex, dev) + res = vm["main"](data) + out = res.numpy() + return ex, vm, out + + +def verify_e2e_translation(target_str, layout, batch_size, image_shape): + target = Target(target_str) + dev = tvm.device(str(target), dev_id=0) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + input_shape = (1, *image_shape) + data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev) + relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) + _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data) + _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data) + tvm.testing.assert_allclose(relay_out, relax_out, atol=1e-5, rtol=1e-5) + + +@pytest.mark.skip(reason="take too much time") +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_cpu(layout, batch_size, image_shape): + verify_e2e_translation("llvm --num-cores=16", layout, batch_size, image_shape) + + +@pytest.mark.skip(reason="take too much time") +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_e2e_translation_gpu(layout, batch_size, image_shape): + verify_e2e_translation("cuda", layout, batch_size, image_shape) + + +def verify_extracted_tasks(target_str, layout, batch_size, image_shape): + target = Target(target_str) + relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + relax_mod = relay_translator.from_relay( + relay_mod["main"], + target, + params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relay_tasks = 
ms.relay_integration.extract_tasks( + relay_mod, + target=target, + params=params, + pass_config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.max_depth": 1, # Disable relay fusion + }, + ) + relax_tasks = ms.relax_integration.extract_tasks( + relax_mod, + target=target, + params=params, + ) + # TODO (yongwww, yuchen): tophub guides relay passes, which causes inconsistent tasks + # assert len(relay_tasks) == len(relax_tasks) + # TODO: Can we compare extracted tasks as well? + + +@pytest.mark.parametrize( + "layout, batch_size, image_shape", + [ + ("NCHW", 1, (3, 224, 224)), + ("NHWC", 1, (224, 224, 3)), + ], +) +def test_verify_extracted_tasks_cpu(layout, batch_size, image_shape): + verify_extracted_tasks("llvm --num-cores=16", layout, batch_size, image_shape) + + +@tvm.testing.requires_gpu +@pytest.mark.parametrize( + "layout, batch_size, image_shape", [("NCHW", 1, (3, 224, 224)), ("NHWC", 1, (224, 224, 3))] +) +def test_verify_extracted_tasks_gpu(layout, batch_size, image_shape): + verify_extracted_tasks("cuda", layout, batch_size, image_shape) + + +def translate_and_build_vms(relay_mod, target_str="llvm", translate_op_with_tir=None): + target = tvm.target.Target(target_str) + + # build the relay IRModule and create relay vm + relay_ex = relay.vm.compile(relay_mod, target) + relay_vm = vm.VirtualMachine(relay_ex, tvm.cpu()) + + # build the relax IRModule and create relax vm + relax_mod = relay_translator.from_relay( + relay_mod["main"], target, translate_op_with_tir=translate_op_with_tir + ) + relax_ex = relax.vm.build(relax_mod, target) + relax_vm = relax.VirtualMachine(relax_ex, tvm.cpu()) + + return relay_vm, relax_vm, relax_mod + + +def verify_vm_outputs( + input_shape, + relay_vm, + relax_vm, + extra_args=[], +): + input = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32)) + + # check correctness by comparing relax and relay result + args = [input] + extra_args + relax_output = relax_vm["main"](*args) + relay_output = relay_vm.run(*args) + tvm.testing.assert_allclose(relay_output.numpy(), relax_output.numpy()) + + +def test_single_dynamic_dim(): + wx, wy = 64, 128 + # create relay module: y = data * weights + bias with dynamic batch dimension + data = relay.var("data", shape=(relay.Any(), wx)) + weights = relay.var("weights", shape=(wx, wy)) + bias = relay.var("bias", shape=(wy,)) + y = relay.nn.matmul(data, weights) + relay_mod = tvm.IRModule.from_expr(relay.Function([data, weights, bias], y + bias)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + weights = tvm.nd.array(np.random.rand(wx, wy).astype(np.float32)) + bias = tvm.nd.array(np.random.rand(wy).astype(np.float32)) + # verify for different batch sizes + verify_vm_outputs([10, wx], relay_vm, relax_vm, [weights, bias]) + verify_vm_outputs([32, wx], relay_vm, relax_vm, [weights, bias]) + + +def test_multiple_dynamic_dims(): + # create relay module: y = a + a, where a has shape = (?, 5, ?) 
+ shape = (relay.Any(), 5, relay.Any()) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a + a)) + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + # verify for different shapes + verify_vm_outputs([2, 5, 10], relay_vm, relax_vm) + verify_vm_outputs([12, 5, 24], relay_vm, relax_vm) + + +def test_layout_transform(): + shape = (relay.Any(), 3, 224, 224) + a = relay.var("a", shape=shape) + b = relay.layout_transform(a, "NCHW", "NHWC") + relay_mod = tvm.IRModule.from_expr(relay.Function([a], b)) + + relay_vm, relax_vm, _ = translate_and_build_vms(relay_mod) + verify_vm_outputs([1, 3, 224, 224], relay_vm, relax_vm) + + +def test_translate_op_with_tir(): + @T.prim_func + def tir_matmul( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "multiply", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(C[i, j], A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + shape = (512, 512) + a = relay.var("a", shape=shape) + + relay_mod = tvm.IRModule.from_expr(relay.Function([a], a * a)) + _, _, relax_mod = translate_and_build_vms( + relay_mod, translate_op_with_tir={"multiply": tir_matmul} + ) + assert_structural_equal(relax_mod["multiply"], tir_matmul) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_structual_equal_hash.py b/tests/python/relax/test_structual_equal_hash.py new file mode 100644 index 0000000000..d605e6a340 --- /dev/null +++ b/tests/python/relax/test_structual_equal_hash.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
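+
+# Structural equality treats two IR objects as equal when they have the same
+# structure, with bound variables allowed to differ by name, and
+# structural_hash is consistent with it: whenever assert_structural_equal(x, y)
+# passes, structural_hash(x) == structural_hash(y) must hold. _check_equal
+# below asserts both directions plus the hash equality, and
+# _check_save_roundtrip additionally round-trips through save_json/load_json.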
+ +from __future__ import annotations # must import to defer parsing of annotations +import pytest +import sys +import tvm +from tvm import relax as rx, tir +from tvm.script import tir as T, relax as R + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x) + yhash = tvm.ir.structural_hash(y) + + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +def test_var_binding(): + dtype = rx.DynTensorType(1) + x = rx.Var("x", [10], dtype) + y = rx.Var("y", [10], dtype) + + def generator(x, y): + bb = rx.BlockBuilder() + bb._begin_binding_block() + bb.emit(rx.op.add(x, y)) + return bb._end_block() + + block0 = generator(x, y) + block1 = generator(x, y) + + _check_equal(block0, block1) + + +def test_match_shape(): + dtype = rx.DynTensorType(1) + x = rx.Var("x", [10], dtype) + m = tir.Var("m", dtype="int32") + + def generator(x): + bb = rx.BlockBuilder() + bb._begin_binding_block() + bb.match_shape(x, [m * 2]) + return bb._end_block() + + block0 = generator(x) + block1 = generator(x) + + _check_equal(block0, block1) + + +def test_function(): + def generator(): + dtype = rx.DynTensorType(1, "float32") + x = rx.Var("x", [10], dtype) + y = rx.Var("y", [10], dtype) + bb = rx.BlockBuilder() + with bb.function("name", [x, y]): + gv = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv) + return bb.get() + + func0 = generator() + func1 = generator() + _check_equal(func0, func1) + + +def test_ir_module(): + def generator(): + dtype = rx.DynTensorType(1, "float32") + bb = rx.BlockBuilder() + x = rx.Var("x", [10], dtype) + y = rx.Var("y", [10], dtype) + with bb.function("test", [x, y]): + gv = bb.emit(rx.op.add(x, y)) + bb.emit_func_output(gv) + + # get global var + func_gv = bb.get().get_global_var("test") + + x = rx.Var("x", [10], dtype) + y = rx.Var("y", [10], dtype) + with bb.function("main", [x, y]): + gv = bb.emit(rx.Call(func_gv, [x, y])) + bb.emit_func_output(gv) + return bb.get() + + mod0 = generator() + mod1 = generator() + _check_equal(mod0, mod1) + + +def test_match_shape_symbolic(): + @tvm.script.ir_module + class InputModule: + @R.function + def f(x: Tensor((_, _), "float32")): + x0 = R.match_shape(x, (n, m)) + return (x0, (n + 1, m)) + + _check_save_roundtrip(InputModule) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py new file mode 100644 index 0000000000..0a3272bfe8 --- /dev/null +++ b/tests/python/relax/test_transform.py @@ -0,0 +1,611 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
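+
+# Pass-level tests for Relax transforms. The FMA tests rewrite the pattern
+# lv0 = multiply(x, y); gv0 = add(lv0, y) into a single ewise_fma(x, y, y)
+# call, but only inside dataflow blocks; CallTIRRewrite lowers relax.call_tir
+# into an explicit relax.builtin.alloc_tensor plus a call into the packed
+# function; VMShapeLower materializes symbolic shapes into a shape heap driven
+# by a generated shape_func PrimFunc; Normalize converts manually constructed
+# IR into ANF and is a no-op on IR that is already normalized.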
+ +from __future__ import annotations # must import to defer parsing of annotations +import pytest +import tvm +from tvm import relax +from tvm import tir +from tvm.ir import structural_equal +from tvm.ir.base import assert_structural_equal +from tvm.ir.module import IRModule + +import tvm.script +from tvm.script import tir as T, relax as R + + +def test_fma_rewrite(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.add(lv0, y) + relax.output(gv0) + gv1 = relax.multiply(x, y) + gv2 = relax.add(gv1, y) + return (gv0, gv1, gv2) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.ewise_fma(x, y, y) + relax.output(gv0) + gv1 = relax.multiply(x, y) + gv2 = relax.add(gv1, y) + return (gv0, gv1, gv2) + + After = relax.transform.RewriteFMA()(Before) + + assert_structural_equal(After, Expected) + + +def test_fma_rewrite_python(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.add(lv0, y) + relax.output(gv0) + gv1 = relax.multiply(x, y) + gv2 = relax.add(gv1, y) + return (gv0, gv1, gv2) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.ewise_fma(x, y, y) + relax.output(gv0) + gv1 = relax.multiply(x, y) + gv2 = relax.add(gv1, y) + return (gv0, gv1, gv2) + + After = relax.transform.EwiseRewriteFMA()(Before) + + assert_structural_equal(After, Expected) + + +def test_fma_fuse(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.add(lv0, y) + relax.output(gv0) + return gv0 + + After = relax.transform.FuseFMA()(Before) + + # TODO(@yuchen): add assert_structural_equal after parser allows CallNode as Function's body + assert len(After.get_global_vars()) == 2 + main = After["main"] + ewise_fma_fused = After["ewise_fma_fused"] + + # check sub function call type inference + assert_structural_equal(ewise_fma_fused.body.checked_type, relax.DynTensorType(2, "float32")) + sub_func_call = main.body.blocks[0].bindings[1].value + sub_func_call_var = main.body.blocks[0].bindings[1].var + assert_structural_equal(sub_func_call.checked_type, relax.DynTensorType(2, "float32")) + assert_structural_equal(sub_func_call_var.checked_type, relax.DynTensorType(2, "float32")) + + # check sub function call shape inference + assert isinstance(ewise_fma_fused.body.shape, relax.ShapeExpr) + assert ewise_fma_fused.body.shape.values[0] == 3 + assert ewise_fma_fused.body.shape.values[1] == 4 + assert sub_func_call.shape.values[0] == 3 + assert sub_func_call.shape.values[1] == 4 + assert sub_func_call_var.shape.values[0] == 3 + assert sub_func_call_var.shape.values[1] == 4 + + +def test_fma_fuse_python(): + @tvm.script.ir_module + class Before: + @R.function + def main(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): + with relax.dataflow(): + lv0 = relax.multiply(x, y) + gv0 = relax.add(lv0, y) + relax.output(gv0) + return gv0 + + After = relax.transform.EwiseFuseFMA()(Before) + + # TODO(@yuchen): add 
assert_structural_equal after parser allows CallNode as Function's body + assert len(After.get_global_vars()) == 2 + main = After["main"] + ewise_fma_fused = After["ewise_fma_fused"] + + # check sub function call type inference + assert_structural_equal(ewise_fma_fused.body.checked_type, relax.DynTensorType(2, "float32")) + sub_func_call = main.body.blocks[0].bindings[1].value + sub_func_call_var = main.body.blocks[0].bindings[1].var + assert_structural_equal(sub_func_call.checked_type, relax.DynTensorType(2, "float32")) + assert_structural_equal(sub_func_call_var.checked_type, relax.DynTensorType(2, "float32")) + + # check sub function call shape inference + assert isinstance(ewise_fma_fused.body.shape, relax.ShapeExpr) + assert ewise_fma_fused.body.shape.values[0] == 3 + assert ewise_fma_fused.body.shape.values[1] == 4 + assert sub_func_call.shape.values[0] == 3 + assert sub_func_call.shape.values[1] == 4 + assert sub_func_call_var.shape.values[0] == 3 + assert sub_func_call_var.shape.values[1] == 4 + + +def test_dataflowpass_fail(): + # raise error on rewriting/removing existing Global Vars inside the dataflow block. + with pytest.raises(tvm.TVMError): + + @tvm.script.ir_module + class TestRemoveGlobalScopeVar: + @R.function + def main(x: Tensor(_, "float32"), y: Tensor(_, "float32")): + with relax.dataflow(): + gv_remove = relax.add(x, y) + gv1 = relax.add(x, y) + relax.output(gv_remove, gv1) + return (gv_remove, gv1) + + relax.transform.FailTestRewrite()(TestRemoveGlobalScopeVar) + + with pytest.raises(tvm.TVMError): + + @tvm.script.ir_module + class TestRewriteGlobalScopeVar: + @R.function + def main(x: Tensor(_, "float32"), y: Tensor(_, "float32")): + with relax.dataflow(): + gv_rewrite = relax.add(x, y) + gv1 = relax.add(x, y) + relax.output(gv_rewrite, gv1) + return (gv_rewrite, gv1) + + relax.transform.FailTestRewrite()(TestRewriteGlobalScopeVar) + + # raise error on rewriting/removing existing Symbolic Vars inside the dataflow block + # check all Symbolic Vars defined in R.match_shape + with pytest.raises(tvm.TVMError): + + @tvm.script.ir_module + class TestRewriteSymbolicVar: + @R.function + def main(x: Tensor(_, "float32"), y: Tensor(_, "float32")): + with relax.dataflow(): + lv0 = R.match_shape(x, (m, n)) + gv0 = relax.add(lv0, y) + relax.output(gv0) + return gv0 + + relax.transform.FailTestRewrite()(TestRewriteSymbolicVar) + + with pytest.raises(tvm.TVMError): + + @tvm.script.ir_module + class TestRemoveSymbolicVar: + @R.function + def main(x: Tensor(_, "float32"), y: Tensor(_, "float32")): + with relax.dataflow(): + lv0 = R.match_shape(x, (m, n, d)) + gv0 = relax.add(lv0, y) + relax.output(gv0) + return gv0 + + relax.transform.FailTestRewrite()(TestRemoveSymbolicVar) + + +def test_visit_shape(): + @tvm.script.ir_module + class TestVisitShape: + @R.function + def foo(x: Tensor((m, n), "float32")): + gv0 = R.add(x, x) + return gv0 + + mod = TestVisitShape + + shape_expr = [] + + def fvisit(e): + if isinstance(e, relax.ShapeExpr): + nonlocal shape_expr + shape_expr.append(e) + + relax.analysis.post_order_visit(mod["foo"], fvisit) + + # should have visited ShapeExpr 3 times + # the first time being visited is x.shape + # the last two times are the call node's shape and gv0's shape + assert len(shape_expr) == 3 + assert shape_expr[0] == mod["foo"].params[0].shape + assert shape_expr[1] == shape_expr[2] + + +def test_to_non_dataflow(): + @tvm.script.ir_module + class TestToNonDataflow: + @R.function + def foo(x: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = 
relax.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + gv0 = relax.call_tir("test.op.identity", (lv0,), (m, n), dtype="float32") + relax.output(gv0) + return gv0 + + mod = TestToNonDataflow + + old_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal old_vars + old_vars.append(e) + + relax.analysis.post_order_visit(mod["foo"], fvisit) + x, lv0, gv0 = old_vars + + new_mod = relax.transform.ToNonDataflow()(mod) + + new_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal new_vars + new_vars.append(e) + + relax.analysis.post_order_visit(new_mod["foo"], fvisit) + + assert x == new_vars[0] + assert lv0 != new_vars[1] + assert isinstance(lv0, relax.DataflowVar) + assert not isinstance(new_vars[1], relax.DataflowVar) + + assert isinstance(gv0, relax.Var) + assert isinstance(new_vars[2], relax.Var) + assert gv0 == new_vars[2] + + +def test_call_tir_rewrite(): + @tvm.script.ir_module + class TestCallTIRRewrite: + @R.function + def foo(x: Tensor((m, n), "float32")): + gv0 = relax.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + return gv0 + + mod = TestCallTIRRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.call_tir" + + # after rewrite + new_mod = relax.transform.CallTIRRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.args[2]) + s2 = block.bindings[1].value + assert s2.op.global_symbol == "test.op.identity" + + +def test_vm_memory_lower(): + @tvm.script.ir_module + class TestVMMemoryLower: + @R.function + def foo(x: Tensor((m, n), "float32")) -> Tensor: + alloc = relax.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") + _ = relax.call_packed( + "test.op.identity", x, alloc, type_args=(Tensor(rank=2, dtype="float32")) + ) + gv0 = alloc + return gv0 + + mod = TestVMMemoryLower + + # after vm memory lowering + new_mod = relax.transform.VMMemoryLower()(mod) + func = new_mod["foo"] + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(func, tvm.relax.expr.Function) + + block = func.body.blocks[0] + s1 = block.bindings[0].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.vm.builtin.alloc_storage" + s2 = block.bindings[1].value + assert isinstance(s2, tvm.relay.Call) + s4 = block.bindings[3].value + assert isinstance(s4, tvm.relay.Call) + assert isinstance(s4.op, relax.ExternFunc) + assert s4.op.global_symbol == "test.op.identity" + + +def test_vm_shape_lowering(): + @tvm.script.ir_module + class TestVMShapeLower: + @R.function + def foo(x: Tensor(_, "float32")): + relax.match_shape(x, (n, m)) + return (n * 2, m * 3) + + mod = TestVMShapeLower + + # after vm shape lowering + new_mod = relax.transform.VMShapeLower()(mod) + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc) + func = new_mod["foo"] + assert isinstance(func, tvm.relax.expr.Function) + + s1 = func.body.blocks[0].bindings[0].value + assert isinstance(s1.op, relax.ExternFunc) + assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap" + assert s1.args[0].values[0] == 4 + s2 = func.body.blocks[1].bindings[0].value + assert 
isinstance(s2.op, relax.ExternFunc) + assert s2.op.global_symbol == "vm.builtin.shape_of" + s3 = func.body.blocks[1].bindings[1].value + assert isinstance(s3, tvm.relay.Call) + assert s3.op.name == "relax.vm.builtin.store_shape" + s4 = func.body.blocks[2].bindings[0].value + assert isinstance(s4.op, relax.GlobalVar) + assert s4.op.name_hint == "shape_func" + s5 = func.body.blocks[2].bindings[1].value + assert isinstance(s5, tvm.relay.Call) + assert s5.op.name == "relax.vm.builtin.load_shape" + + +def test_vm_static_shape_lowering(): + @tvm.script.ir_module + class TestVMStaticShapeLower: + @R.function + def foo(x: Tensor((2, 3), "float32")): + with relax.dataflow(): + y = R.call_tir("test.vm.tile", (x), (2, 6), dtype="float32") + relax.output(y) + return y + + mod = TestVMStaticShapeLower + + # after vm shape lowering + new_mod = relax.transform.VMShapeLower()(mod) + + # before and after programs should be structurally equal + # since the program only has static shapes + assert_structural_equal(mod, new_mod) + + +def test_vm_shape_lowering_func_param_with_shape(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def foo(x: Tensor((m, n), "float32"), w: Tensor((n, k), "float32")): + gv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32") + return gv0 + + mod = InputModule + + # after vm shape lowering + new_mod = relax.transform.VMShapeLower()(mod) + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc) + assert isinstance(new_mod["tir_matmul"], tvm.tir.function.PrimFunc) + func = new_mod["foo"] + assert isinstance(func, tvm.relax.expr.Function) + + x, w = func.params + s1 = func.body.blocks[0].bindings[0].value + assert isinstance(s1.op, relax.ExternFunc) + assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap" + assert s1.args[0].values[0] == 3 + + s2 = func.body.blocks[0].bindings[1].value + assert isinstance(s2.op, relax.ExternFunc) + assert s2.op.global_symbol == "vm.builtin.shape_of" + assert s2.args[0] == x + s3 = func.body.blocks[0].bindings[2].value + assert isinstance(s3, tvm.relay.Call) + assert s3.op.name == "relax.vm.builtin.store_shape" + + s4 = func.body.blocks[0].bindings[3].value + assert isinstance(s4.op, relax.ExternFunc) + assert s4.op.global_symbol == "vm.builtin.shape_of" + assert s4.args[0] == w + s5 = func.body.blocks[0].bindings[2].value + assert isinstance(s5, tvm.relay.Call) + assert s5.op.name == "relax.vm.builtin.store_shape" + + +def test_vm_shape_lower_int32_shape(): + @tvm.script.ir_module + class InputModule: + @R.function + def foo(x: Tensor((d,), "float32")): + gv0 = R.call_tir("my_extern", (x,), (tir.cast("int32", d),), dtype="float32") + return gv0 + + before_mod = InputModule + after_mod = relax.transform.VMShapeLower()(before_mod) + + assert isinstance(after_mod, tvm.IRModule) + shape_func = after_mod["shape_func"] + assert isinstance(shape_func, tvm.tir.function.PrimFunc) + buffer_store_stmt = shape_func.body[0] + assert isinstance(buffer_store_stmt, tvm.tir.stmt.BufferStore) + # verify that 
the value in BufferStore stmt is cast to int64 first + cast_expr = buffer_store_stmt.value + assert isinstance(cast_expr, tvm.tir.expr.Cast) + assert cast_expr.dtype == "int64" + + +def test_normalize_function(): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + type_anno = relax.DynTensorType(ndim=2, dtype="float16") + x = relax.Var("x", [m, n], type_anno) + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function here. + mul_add = relax.Function( + [x], + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + ret_type=type_anno, + ret_shape=relax.RuntimeDepShape(), + ) + mul_add = mul_add.with_attr("global_symbol", "mul_add") + before_mod = tvm.IRModule.from_expr(mul_add) + + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def mul_add(x: Tensor((m, n), "float16")) -> Tensor(None, "float16", ndim=2): + gv = R.add(x, x) + gv1 = relax.add(x, x) + return R.multiply(gv, gv1) + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_if(): + cond = relax.Var("cond", [], type_annotation=relax.DynTensorType(0, "bool")) + x = relax.Var("x", [tir.IntImm("int64", 1)], type_annotation=relax.DynTensorType(1, "float32")) + # TODO(relax-team): add type and shape inference for IfNode + y = relax.Var("y") + + # Note: the parser automatically normalize the IR written in TVMScript, + # so we manually construct the function and If here. + f = relax.Function( + [cond, x], + relax.SeqExpr( + [ + relax.BindingBlock( + [ + relax.VarBinding( + y, + relax.If( + cond, + relax.op.multiply(relax.op.add(x, x), relax.op.add(x, x)), + relax.op.add(relax.op.multiply(x, x), relax.op.multiply(x, x)), + ), + ) + ] + ) + ], + y, + ), + ret_type=relax.DynTensorType(1, "float32"), + ret_shape=relax.RuntimeDepShape(), + ) + + f = f.with_attr("global_symbol", "f") + before_mod = tvm.IRModule.from_expr(f) + after_mod = relax.transform.Normalize()(before_mod) + + @tvm.script.ir_module + class Expected: + @R.function + def f( + cond: Tensor((), "bool"), x: Tensor((1,), "float32") + ) -> Tensor(None, "float32", ndim=1): + if cond: + gv = R.add(x, x) + gv1 = R.add(x, x) + y = R.multiply(gv, gv1) + else: + gv = R.multiply(x, x) + gv1 = R.multiply(x, x) + y = R.add(gv, gv1) + return y + + assert_structural_equal(after_mod, Expected) + + +def test_normalize_no_op(): + # the normalize pass should be no-op for IR in ANF + @tvm.script.ir_module + class ANFMod1: + @R.function + def f(x: Tensor(_, "float32")): + gv = relax.add(x, x) + gv1 = relax.add(gv, gv) + gv2 = relax.add(gv, gv1) + return (gv, gv2) + + before_mod = ANFMod1 + after_mod = relax.transform.Normalize()(before_mod) + assert_structural_equal(before_mod, after_mod, map_free_vars=True) + + @tvm.script.ir_module + class ANFMod2: + @R.function + def foo(x: Tensor((m, n), "float32")): + with relax.dataflow(): + lv0 = relax.call_tir("test.op.identity", (x,), (m, n), dtype="float32") + gv0 = relax.call_tir("test.op.identity", (lv0,), (m, n), dtype="float32") + relax.output(gv0) + return gv0 + + mod = ANFMod2 + mod_post = relax.transform.Normalize()(mod) + + assert_structural_equal(mod, mod_post) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py new file mode 100644 index 0000000000..2f1cc6301d --- /dev/null +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -0,0 +1,311 @@ 
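
The next file exercises relax.transform.AnnotateTIROpPattern, which stamps the
PrimFuncs it analyzes with an integer "op_pattern" attribute whose values match
the OpPatternKind IntEnum defined at the top of the file. A minimal sketch of
how the annotation is produced and consumed, assuming only the pass name and
attribute key used by the tests (the helper name is hypothetical):

    import tvm
    from tvm import relax

    def annotate_and_read(mod: tvm.IRModule, func_name: str):
        # run the pass, then read back the integer pattern it attached;
        # the result compares directly against OpPatternKind members
        annotated = relax.transform.AnnotateTIROpPattern()(mod)
        return annotated[func_name].attrs["op_pattern"]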
+# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import enum +import pytest +import tvm +from tvm import relax + +import tvm.script +from tvm.script import tir as T + + +class OpPatternKind(enum.IntEnum): + kElemWise = 0 + kBroadcast = 1 + kInjective = 2 + kCommReduce = 3 + kOutEWiseFusable = 4 + kTuple = 7 + kOpaque = 8 + + +def test_annotate_opkind_outewisefusable(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable + + +def test_annotate_opkind_reduce(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def sum(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16,)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SR", [i, j]) + with T.init(): + B[vi] = 0.0 + B[vi] += A[vi, vj] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce + + +def test_annotate_opkind_ewise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def elemwise(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_broadcast(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def broadcast(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16, 16, 16)) + + for i0, j0, i1, j1 in T.grid(16, 16, 16, 16): + with T.block("matmul"): + vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1]) + B[vi0, vj0, vi1, vj1] = A[vj0, vj1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast + + +def 
test_annotate_opkind_injective(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def injective(x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "elemwise"}) + A = T.match_buffer(x, (4, 4, 4, 4)) + B = T.match_buffer(y, (16, 16)) + + for i, j in T.grid(16, 16): + with T.block("matmul"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective + + +def test_annotate_opkind_bias_add(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_bias_add( + A: T.Buffer[(1, 1000), "float32"], + B: T.Buffer[(1000,), "float32"], + C: T.Buffer[(1, 1000), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True}) + # body + # with T.block("root") + for i0, i1 in T.grid(1, 1000): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax1] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_broadcast_with_unit_shape(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_with_unit_dim_len_broadcast( + A: T.Buffer[(1, 64, 112, 112), "float32"], + B: T.Buffer[(64, 1, 1), "float32"], + C: T.Buffer[(1, 64, 112, 112), "float32"], + ) -> None: + T.func_attr({"global_symbol": "add5", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(1, 64, 112, 112): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0]) + T.writes(C[ax0, ax1, ax2, ax3]) + C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_add_zero_dim_element_wise(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def add_zero_dim( + A: T.Buffer[(128,), "float32"], + B: T.Buffer[(), "float32"], + C: T.Buffer[(128,), "float32"], + ) -> None: + T.func_attr({"global_symbol": "add8", "tir.noalias": True}) + for i0 in T.serial(128): + with T.block("T_add"): + ax0 = T.axis.spatial(128, i0) + T.reads(A[ax0], B[()]) + T.writes(C[ax0]) + C[ax0] = A[ax0] + B[()] + + mod = InputModule + new_mod = relax.transform.AnnotateTIROpPattern()(mod) + assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise + + +def test_annotate_opkind_pooling(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def max_pool2d( + rxplaceholder_1: T.Buffer[(1, 64, 112, 112), "float32"], + tensor_1: T.Buffer[(1, 64, 56, 56), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "max_pool2d", "T.noalias": True}) + # body + # with T.block("root") + pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 64, 114, 114): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1]) + T.writes(pad_temp_1[ax0, ax1, ax2, ax3]) + pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else( + 1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, + rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], + 
T.float32(-3.4028234663852886e38),
+                        dtype="float32",
+                    )
+            for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3):
+                with T.block("tensor"):
+                    ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
+                    T.reads(
+                        tensor_1[ax0, ax1, ax2, ax3],
+                        pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
+                    )
+                    T.writes(tensor_1[ax0, ax1, ax2, ax3])
+                    with T.init():
+                        tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38)
+                    tensor_1[ax0, ax1, ax2, ax3] = T.max(
+                        tensor_1[ax0, ax1, ax2, ax3],
+                        pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
+                    )
+
+    mod = InputModule
+    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
+    assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
+def test_annotate_opkind_softmax():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def softmax(
+            rxplaceholder_1: T.Buffer[(16, 16), "float32"],
+            T_softmax_norm_1: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            # function attr dict
+            T.func_attr({"global_symbol": "softmax", "tir.noalias": True})
+            # body
+            # with T.block("root")
+            T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32")
+            T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32")
+            T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32")
+            for i0_7, i1_3 in T.grid(16, 16):
+                with T.block("T_softmax_maxelem"):
+                    i0_8, k = T.axis.remap("SR", [i0_7, i1_3])
+                    T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
+                    T.writes(T_softmax_maxelem_1[i0_8])
+                    with T.init():
+                        T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38)
+                    T_softmax_maxelem_1[i0_8] = T.max(
+                        T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]
+                    )
+            for i0_9, i1_4 in T.grid(16, 16):
+                with T.block("T_softmax_exp"):
+                    i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4])
+                    T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
+                    T.writes(T_softmax_exp_1[i0_10, i1_5])
+                    T_softmax_exp_1[i0_10, i1_5] = T.exp(
+                        rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32"
+                    )
+            for i0_11, i1_6 in T.grid(16, 16):
+                with T.block("T_softmax_expsum"):
+                    i0_12, k = T.axis.remap("SR", [i0_11, i1_6])
+                    T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
+                    T.writes(T_softmax_expsum_1[i0_12])
+                    with T.init():
+                        T_softmax_expsum_1[i0_12] = T.float32(0)
+                    T_softmax_expsum_1[i0_12] = (
+                        T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
+                    )
+            for i0_13, i1_7 in T.grid(16, 16):
+                with T.block("T_softmax_norm"):
+                    i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7])
+                    T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
+                    T.writes(T_softmax_norm_1[i0_14, i1_8])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm_1[i0_14, i1_8] = (
+                        T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]
+                    )
+
+    mod = InputModule
+    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
+    assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py
new file mode 100644
index 0000000000..3826e44afe
--- /dev/null
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations # must import to defer parsing of annotations +import sys +import pytest + +import tvm +import tvm.testing +from tvm import relax +import numpy as np + +import tvm.script +from tvm.script import tir as T, relax as R + +use_np_array = tvm.testing.parameter(False, True) + + +def test_bind_params(use_np_array): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + A = T.match_buffer(x, (16, 16)) + B = T.match_buffer(y, (16, 16)) + C = T.match_buffer(z, (16, 16)) + for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4): + with T.block("matmul"): + vi = T.axis.S(16, i0 * 4 + i1) + vj = T.axis.S(16, j) + vk = T.axis.R(16, k0 * 4 + k1) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main( + x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + gv0 = R.call_tir(tir_matmul, (x, w), (16, 16), dtype="float32") + return gv0 + + x_np = np.random.rand(16, 16).astype(np.float32) + w_np = np.random.rand(16, 16).astype(np.float32) + x_tvm = tvm.nd.array(x_np) + w_tvm = tvm.nd.array(w_np) + params_dict = {"w": w_np if use_np_array else w_tvm} + mod = relax.transform.BindParams("main", params_dict)(InputModule) + assert len(mod["main"].params) == 1 + + target = tvm.target.Target("llvm") + ex_after = relax.vm.build(mod, target) + vm_after = relax.VirtualMachine(ex_after, tvm.cpu()) + res_after = vm_after["main"](x_tvm) + + ex_before = relax.vm.build(InputModule, target) + vm_before = relax.VirtualMachine(ex_before, tvm.cpu()) + res_before = vm_before["main"](x_tvm, w_tvm) + + tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy()) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py new file mode 100644 index 0000000000..42b7f7381e --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
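+
+# The tests below exercise relax.transform.CanonicalizeBindings: chains of
+# trivial var-to-var bindings are rewritten so that later uses refer directly
+# to the earliest bound var, while bindings that carry user intent or change
+# shape information (casts to Object, non-trivial R.match_shape) are preserved.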
+ +from __future__ import annotations # must import to defer parsing of annotations +import pytest +import tvm +from tvm import relax +from tvm.ir.base import assert_structural_equal + +from tvm.runtime.object import Object + +import tvm.script +from tvm.script import relax as R + + +def test_simple_assignments(): + @tvm.script.ir_module + class TestChainAssignments: + @R.function + def main(x: Tensor): + y = x + z = y + q = z + p = q + o = p + return o + + # a little annoying to have these unused bindings around + # but they can be eliminated in a separate pass + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor): + y = x + z = x + q = x + p = x + o = x + return x + + new_mod = relax.transform.CanonicalizeBindings()(TestChainAssignments) + assert_structural_equal(new_mod, Expected) + + +def test_dataflow_block(): + @tvm.script.ir_module + class TestDataflowAssignments: + @R.function + def main(x: Tensor): + with R.dataflow(): + y = relax.const(1) + z = y + o = z + p = o + m = p + n = m + R.output(n) + return n + + # a little annoying to have these unused bindings around + # but they can be eliminated in a separate pass + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor): + with R.dataflow(): + y = relax.const(1) + z = y + o = y + p = y + m = y + # we can't get rid of n because it leaves the block + n = y + R.output(n) + return n + + new_mod = relax.transform.CanonicalizeBindings()(TestDataflowAssignments) + assert_structural_equal(new_mod, Expected) + + +def test_ops(): + @tvm.script.ir_module + class TestOps: + @R.function + def main(x: Tensor, y: Tensor): + w = y + q = x + z = relax.add(w, q) + return relax.add(q, z) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor, y: Tensor): + w = y + q = x + z = relax.add(y, x) + return relax.add(x, z) + + new_mod = relax.transform.CanonicalizeBindings()(TestOps) + assert_structural_equal(new_mod, Expected) + + +def test_casting(): + @tvm.script.ir_module + class TestCasting: + @R.function + def main(x: Tensor) -> Object: + y = x + # z will be treated as object type even though it's a tensor + z: Object = y + return z + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor) -> Object: + y = x + # Cannot unify because the cast indicates user intent + z: Object = x + return z + + new_mod = relax.transform.CanonicalizeBindings()(TestCasting) + assert_structural_equal(new_mod, Expected) + + +def test_match_shape(): + @tvm.script.ir_module + class TestMatchShape: + @R.function + def main(x: Tensor): + q = x + z = R.match_shape(q, (m, n)) + w = z + return w + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor): + q = x + # can't get rid of z because its shape_ is different from x's + z = R.match_shape(x, (m, n)) + w = z + return z + + new_mod = relax.transform.CanonicalizeBindings()(TestMatchShape) + assert_structural_equal(new_mod, Expected) + + +def test_same_shape(): + @tvm.script.ir_module + class TestSameShape: + @R.function + def main(x: Tensor((m, n), _)): + y = x + # trivial check + z = R.match_shape(x, (m, n)) + w = z + q = relax.add(w, y) + return relax.add(q, w) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((m, n), _)): + y = x + # canonicalized into a var binding + z = x + w = x + q = relax.add(x, x) + return relax.add(q, x) + + new_mod = relax.transform.CanonicalizeBindings()(TestSameShape) + assert_structural_equal(new_mod, Expected) + + +def test_change_shape(): + 
@tvm.script.ir_module + class TestChangeShape: + @R.function + def main(x: Tensor((m, n), _)): + y = x + # not trivial: introduces new shape vars + z = R.match_shape(x, (o, p)) + w = z + q = relax.add(w, y) + return relax.add(q, w) + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((m, n), _)): + y = x + z = R.match_shape(x, (o, p)) + w = z + # the shape_ field on q will need to be updated + q = relax.add(z, x) + return relax.add(q, z) + + new_mod = relax.transform.CanonicalizeBindings()(TestChangeShape) + assert_structural_equal(new_mod, Expected) + + +def test_unbound_match_shape(): + # ensure that match shapes that do not bind a var are handled correctly + @tvm.script.ir_module + class TestUnboundMatchShape: + @R.function + def main(x: Tensor): + y = x + z = y + R.match_shape(z, (m, n)) + w = z + return w + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor): + y = x + z = x + R.match_shape(x, (m, n)) + w = x + return x + + new_mod = relax.transform.CanonicalizeBindings()(TestUnboundMatchShape) + assert_structural_equal(new_mod, Expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py new file mode 100644 index 0000000000..55cb27bb28 --- /dev/null +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import pytest +import os +import tvm +import tvm.testing +from tvm import relax +import numpy as np +from tvm.script import relax as R +from tvm.relax.testing import transform +import tempfile +from tvm.relax.transform.tuning_api import Trace +from tvm import meta_schedule as ms + +env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) +env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) + +has_tensorrt_codegen = pytest.mark.skipif( + not env_checker_codegen, + reason="TensorRT codegen not available", +) +has_tensorrt_runtime = pytest.mark.skipif( + not env_checker_runtime or not env_checker_runtime(), + reason="TensorRT runtime not available", +) + +# Global variable in pytest that applies markers to all tests. 
+pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime]
+
+# Target GPU
+target_str = "nvidia/nvidia-t4"
+target = tvm.target.Target(target_str)
+dev = tvm.cuda()
+
+
+def check_executable(exec_mod, dev, inputs, expected):
+    vm = relax.VirtualMachine(exec_mod, dev)
+    out = vm["main"](*inputs)
+    tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
+
+
+def check_roundtrip(exec0, dev, inputs, expected):
+    exec0.mod.export_library("exec.so")
+    exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so"))
+    os.remove("exec.so")
+    assert exec0.stats() == exec1.stats()
+    assert exec0.as_text() == exec1.as_text()
+
+    check_executable(exec0, dev, inputs, expected)
+    check_executable(exec1, dev, inputs, expected)
+
+
+def gen_ground_truth(mod, target, dev, inputs):
+    # Lower and run tuning
+    # Since there is no default schedule for GPU in MS yet, this is necessary
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0):
+            seq = tvm.transform.Sequential(
+                [
+                    transform.LowerWithRelayOpStrategyPass(target),
+                    relax.transform.MetaScheduleTuneIRMod(
+                        params={}, work_dir=work_dir, max_trials_global=8
+                    ),
+                    relax.transform.MetaScheduleApplyDatabase(work_dir),
+                ]
+            )
+            new_mod = seq(mod)
+            assert relax.analysis.well_formed(new_mod)
+            exec_mod = relax.vm.build(new_mod, target, params={})
+            vm = relax.VirtualMachine(exec_mod, dev)
+            return vm["main"](*inputs)
+
+
+@tvm.testing.requires_gpu
+def test_single_annot_func():
+    @tvm.script.ir_module
+    class InputModule:
+        @R.function
+        def relax_func(
+            x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32")
+        ) -> Tensor((16, 16), "float32"):
+            z1 = relax.multiply(x, y)
+            z2 = relax.add(z1, z1)
+            z3 = relax.add(z1, z2)
+            return z3
+
+        @R.function
+        def main(
+            x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32")
+        ) -> Tensor((16, 16), "float32"):
+            lv0: Tensor((16, 16), "float32") = relax_func(x, y)
+            return lv0
+
+    # Prepare IRModule and its input
+    mod = InputModule
+    assert isinstance(mod, tvm.IRModule)
+
+    np0 = np.random.rand(16, 16).astype(np.float32)
+    np1 = np.random.rand(16, 16).astype(np.float32)
+    data0 = tvm.nd.array(np0, dev)
+    data1 = tvm.nd.array(np1, dev)
+    inputs = [data0, data1]
+
+    # Ground truth should be generated before annotation
+    # due to the conflict with MS task extraction
+    # TODO(@sunggg): Sort this out
+    expected = gen_ground_truth(mod, target, dev, inputs)
+
+    # TODO(@sunggg): Revisit when TVMScript supports annotation.
+    # Annotate target function.
+    new_relax_func = mod["relax_func"].with_attr("Codegen", "tensorrt")
+    new_relax_func = new_relax_func.with_attr("global_symbol", "trt_relax_func")
+    mod["relax_func"] = new_relax_func
+
+    # Run Codegen pass
+    seq = tvm.transform.Sequential(
+        [relax.transform.RunCodegen(), relax.transform.RemoveUnusedFunctions()]
+    )
+
+    new_mod = seq(mod)
+    ex0 = relax.vm.build(new_mod, target, params={})
+
+    # Sanity check for the correctness and roundtrip
+    check_roundtrip(ex0, dev, inputs, expected)
+
+    # If the annotation does not match the target codegen, do not perform the codegen process.
+    new_mod = relax.transform.RunCodegen(target_codegens=["INVALID_CODEGEN"])(mod)
+    # TODO(tvm-team): Currently disabled due to the lack of type annotation support during parser.
+    # Revisit when new version of parser is available.
+    # tvm.ir.assert_structural_equal(mod, new_mod)
+
+
+@tvm.testing.requires_gpu
+def test_mix_use_tensorrt_and_tvm():
+    @tvm.script.ir_module
+    class InputModule:
+        @R.function
+        def byoc_func(
+            x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32")
+        ) -> Tensor((16, 16), "float32"):
+            z1 = relax.multiply(x, y)
+            z2 = relax.add(z1, z1)
+            z3 = relax.add(z1, z2)
+            return z3
+
+        @R.function
+        def tvm_func(
+            x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")
+        ) -> Tensor((16, 16), "float32"):
+            gv0 = R.multiply(x, w)
+            gv1 = R.add(x, gv0)
+            return gv1
+
+        @R.function
+        def main(
+            x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32")
+        ) -> Tensor((16, 16), "float32"):
+            lv0 = byoc_func(x, y)
+            lv1 = tvm_func(x, lv0)
+            return lv1
+
+    # Prepare IRModule and its inputs
+    mod = InputModule
+    assert isinstance(mod, tvm.IRModule)
+
+    np0 = np.random.rand(16, 16).astype(np.float32)
+    np1 = np.random.rand(16, 16).astype(np.float32)
+    data0 = tvm.nd.array(np0, dev)
+    data1 = tvm.nd.array(np1, dev)
+    inputs = [data0, data1]
+    expected = gen_ground_truth(mod, target, dev, [data0, data1])
+
+    # TODO(@sunggg): Revisit when TVMScript supports annotation.
+    # Annotate target function.
+    new_byoc_func = mod["byoc_func"].with_attr("Codegen", "tensorrt")
+    new_byoc_func = new_byoc_func.with_attr("global_symbol", "trt_byoc_func")
+    mod["byoc_func"] = new_byoc_func
+
+    # Run Codegen pass
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=3):
+            seq = tvm.transform.Sequential(
+                [
+                    relax.transform.RunCodegen(),
+                    relax.transform.RemoveUnusedFunctions(),
+                    transform.LowerWithRelayOpStrategyPass(target),
+                    relax.transform.MetaScheduleTuneIRMod(
+                        params={}, work_dir=work_dir, max_trials_global=8
+                    ),
+                    relax.transform.MetaScheduleApplyDatabase(work_dir),
+                ]
+            )
+            new_mod = seq(mod)
+            assert relax.analysis.well_formed(new_mod)
+            with tvm.transform.PassContext(opt_level=0):
+                ex0 = relax.vm.build(new_mod, target, params={})
+
+    # Sanity check for the correctness and roundtrip
+    check_roundtrip(ex0, dev, inputs, expected)
+
+
+# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
new file mode 100644
index 0000000000..d00979cdfc
--- /dev/null
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
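+
+# The tests below exercise relax.transform.FoldConstant: a call_tir whose
+# inputs are all constants with statically known shapes is evaluated at
+# compile time and replaced by the resulting constant, while calls with
+# symbolic shapes or non-constant inputs are left untouched.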
+from __future__ import annotations  # must import to defer parsing of annotations
+import pytest
+import sys
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+import numpy as np
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+def gen_mod(mod, name, binding):
+    """Select the relax function with the given name, rename it to main, and bind constants.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The input module
+
+    name: str
+        The name of the relax function to preserve and rename to main
+
+    binding: Dict[str, array]
+        The const parameter bindings
+    """
+    funcs = {}
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+
+    for k, v in mod.functions.items():
+        if isinstance(v, tvm.relax.Function):
+            if k.name_hint == name:
+                # rename to main
+                gv = tvm.ir.GlobalVar("main")
+                funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_type, v.ret_shape).with_attr(
+                    "global_symbol", "main"
+                )
+        else:
+            funcs[k] = v
+    mod = tvm.IRModule(funcs)
+    return relax.transform.BindParams("main", binding)(mod)
+
+
+def test_one_fold_addone():
+    # put the before and expected functions in a single module
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        @R.function
+        def before(c0: Tensor((16, 16), "float32")):
+            lv0 = relax.call_tir(addone, (c0,), (16, 16), dtype="float32")
+            return lv0
+
+        @R.function
+        def expected(c1: Tensor((16, 16), "float32")):
+            lv0 = c1
+            return c1
+
+    c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+    c1_np = c0_np + 1
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_one_fold_transpose():
+    # put the before and expected functions in a single module
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]) -> None:
+            for i, j in T.grid(3, 2):
+                with T.block("transpose"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vj, vi]
+
+        @R.function
+        def before(c0: Tensor((2, 3), "float32")):
+            lv0 = relax.call_tir(func, (c0,), (3, 2), dtype="float32")
+            return lv0
+
+        @R.function
+        def expected(c1: Tensor((3, 2), "float32")):
+            lv0 = c1
+            return c1
+
+    c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
+    c1_np = c0_np.T
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_two_hop_addone():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2), "float32"]) -> None:
+            for i, j in T.grid(2, 2):
+                with T.block("addone"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] + T.float32(1)
+
+        @R.function
+        def before(c0: Tensor((2, 2), "float32")):
+            lv0 = relax.call_tir(addone, (c0,), (2, 2), dtype="float32")
+            lv1 = relax.call_tir(addone, (lv0,), (2, 2), dtype="float32")
+            return lv1
+
+        @R.function
+        def expected(c1: Tensor((2, 2), "float32"), c2: Tensor((2, 2), "float32")):
+            lv0 = c1
+            lv1 = c2
+            return c2
+
+    c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
+    c1_np = c0_np + 1
+    c2_np = c1_np + 1
+    before = gen_mod(Module,
"before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_dataflow_fold(): + @tvm.script.ir_module + class Module: + @T.prim_func + def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16), "float32"]) -> None: + for i, j in T.grid(16, 16): + with T.block("identity"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + + @R.function + def before(c0: Tensor((16, 16), "float32")): + with R.dataflow(): + gv0 = relax.call_tir(identity, (c0,), (16, 16), dtype="float32") + R.output(gv0) + return gv0 + + @R.function + def expected(c1: Tensor((16, 16), "float32")): + with R.dataflow(): + gv0 = c1 + R.output(gv0) + return c1 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_fold_mixed_case(): + @tvm.script.ir_module + class Module: + # TIR function can handle different cases. + @T.prim_func + def addone(a: T.handle, b: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + A = T.match_buffer(a, (n, m)) + B = T.match_buffer(b, (n, m)) + for i, j in T.grid(n, m): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + @T.prim_func + def sub( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: Tensor((16, 16), "float32"), x: Tensor((_, _), "float32")): + x0 = R.match_shape(x, (n, m)) + # this line cannot be folded because n is unknown + lv0 = relax.call_tir(addone, (c0,), (n, 16), dtype="float32") + # this line can be folded + lv1 = relax.call_tir(addone, (c0,), (16, 16), dtype="float32") + # this line can be folded because all inputs are const + lv2 = relax.call_tir(sub, (c0, lv1), (16, 16), dtype="float32") + # this line can not be folded because x's shape is unknown + lv3 = relax.call_tir(sub, (lv2, x), (16, 16), dtype="float32") + return lv3 + + @R.function + def expected( + c0: Tensor((16, 16), "float32"), + c1: Tensor((16, 16), "float32"), + c2: Tensor((16, 16), "float32"), + x: Tensor((_, _), "float32"), + ) -> Tensor: + x0 = R.match_shape(x, (n, m)) + # this line cannot be folded because n is unknown + lv0 = relax.call_tir(addone, (c0,), (n, 16), dtype="float32") + # this line can be folded + lv1 = c1 + # this line can be folded because all inputs are const + lv2 = c2 + # this line can not be folded because x's shape is unknown + lv3 = relax.call_tir(sub, (c2, x), (16, 16), dtype="float32") + return lv3 + + c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) + c1_np = c0_np + 1 + c2_np = c0_np - c1_np + + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c0": c0_np, "c1": c1_np, "c2": c2_np}) + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +def test_int32_fold(): + @tvm.script.ir_module + class Module: + @T.prim_func + def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = 
T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def before(c0: Tensor((16, 16), "int32")): + lv0 = relax.call_tir(addone, (c0,), (16, 16), dtype="int32") + return lv0 + + @R.function + def expected(c1: Tensor((16, 16), "int32")): + lv0 = c1 + return c1 + + c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) + c1_np = c0_np + 1 + before = gen_mod(Module, "before", {"c0": c0_np}) + expected = gen_mod(Module, "expected", {"c1": c1_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py new file mode 100644 index 0000000000..e02f20de8e --- /dev/null +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -0,0 +1,777 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys + +import pytest +import tvm +from tvm import relax, topi + + +def _check(mod_actual, mod_expected): + mod_actual = relax.transform.AnnotateTIROpPattern()(mod_actual) + mod_actual = relax.transform.FuseOps()(mod_actual) + mod_expected = relax.transform.AnnotateTIROpPattern()(mod_expected) + tvm.ir.assert_structural_equal(mod_actual, mod_expected) + + +def test_fuse_simple(): + """Simple testcase.""" + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")]) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + """Test fusion case of conv2d""" + + def before(dtype): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, dtype) + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type) + w2 = 
relax.Var("w2", (16, 16, 1, 1), tensor_type) + w3 = relax.Var("w3", (16, 16, 3, 3), tensor_type) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(topi.nn.conv2d, lv0, w1, strides=1, padding=1, dilation=1) + # this is the next dominator. + lv2 = bb.emit_te(topi.add, relax.const(1, dtype), lv1) + lv3 = bb.emit_te(topi.add, lv1, lv2) + # second path + lv4 = bb.emit_te(topi.nn.conv2d, lv3, w2, strides=1, padding=0, dilation=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv3, w3, strides=1, padding=1, dilation=1) + gv = bb.emit_output(bb.call_te(topi.add, lv4, lv5)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dtype): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, dtype) + + # Grouped function 1 + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w = relax.Var("w", (16, 16, 3, 3), tensor_type) + p0 = relax.Var("p0", (), relax.DynTensorType(0, dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w = relax.Var("w", (16, 16, 1, 1), tensor_type) + y = relax.Var("y", (1, 16, 64, 64), tensor_type) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type) + w2 = relax.Var("w2", (16, 16, 1, 1), tensor_type) + w3 = relax.Var("w3", (16, 16, 3, 3), tensor_type) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + _check(before("float16"), expected("float16")) + _check(before("int8"), expected("int8")) + + +def test_concatenate(): + """Test fusion case involving concat op and Tuple node""" + + def before(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + lv2 = bb.emit_te(topi.concatenate, (lv1, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv2, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return 
bb.get() + + def expected(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + + # Grouped function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w = relax.Var("w", (1, 16, 32, 32), tensor_type) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32")) + with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0) + lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1) + gv = bb.emit_output(bb.call_te(topi.add, lv1, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_upsampling_concatenate_add = bb.get().get_global_var( + "fused_upsampling_concatenate_add" + ) + + # Main function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output( + relax.Call( + fused_upsampling_concatenate_add, (lv0, x, relax.const(1, "float32")) + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_tuple_root(): + """Test fusion case where Tuple node is the root in its group""" + + def before(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + x, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0) + gv = bb.emit_output((lv1, x)) + bb.emit_func_output(gv) + + return bb.get() + + # The fusion is supposed to make no change. 
+ _check(before(), before()) + + +def test_fuse_tuple_get_elemwise(): + def before(dim: int): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(2, "float32") + x = relax.Var("x", (1, dim), tensor_type) + w = relax.Var("w", (3 * dim, dim), tensor_type) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + lv1 = bb.emit_te(topi.split, lv0, indices_or_sections=3, axis=1) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit_te(topi.sigmoid, lv2) + lv4 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv5 = bb.emit_te(topi.tanh, lv4) + lv6 = bb.emit(relax.TupleGetItem(lv1, 2)) + lv7 = bb.emit_te(topi.exp, lv6) + lv8 = bb.emit_te(topi.multiply, lv5, lv7) + gv = bb.emit_output(bb.call_te(topi.add, lv3, lv8)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(2, "float32") + + # Grouped function + dense = relax.Var("dense", (1, 3 * dim), tensor_type) + with bb.function( + "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + lv2 = bb.emit_te(topi.sigmoid, lv1) + lv3 = bb.emit(relax.TupleGetItem(lv0, 1)) + lv4 = bb.emit_te(topi.tanh, lv3) + lv5 = bb.emit(relax.TupleGetItem(lv0, 2)) + lv6 = bb.emit_te(topi.exp, lv5) + lv7 = bb.emit_te(topi.multiply, lv4, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv7)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split_sigmoid_tanh_exp_multiply_add = bb.get().get_global_var( + "fused_split_sigmoid_tanh_exp_multiply_add" + ) + + # Main function + x = relax.Var("x", (1, dim), tensor_type) + w = relax.Var("w", (3 * dim, dim), tensor_type) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.dense, x, w) + gv = bb.emit_output(relax.Call(fused_split_sigmoid_tanh_exp_multiply_add, (lv0,))) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_get_root(): + def before(dim: int): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(2, "float32") + x = relax.Var("x", (1, 3 * dim), tensor_type) + w = relax.Var("w", (dim, dim), tensor_type) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + lv1 = bb.emit(relax.TupleGetItem(lv0, 0)) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv1, w)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(dim: int): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(2, "float32") + + # Grouped function + x = relax.Var("x", (1, 3 * dim), tensor_type) + with bb.function("fused_split", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) + gv = bb.emit_output(relax.TupleGetItem(lv0, 0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_split = bb.get().get_global_var("fused_split") + + # Main function + x = relax.Var("x", (1, 3 * dim), tensor_type) + w = relax.Var("w", (dim, dim), tensor_type) + with bb.function("main", [x, w]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_split, (x,))) + gv = bb.emit_output(bb.call_te(topi.nn.dense, lv0, w)) + bb.emit_func_output(gv) + + return bb.get() + + dim = 10 + _check(before(dim), expected(dim)) + + +def test_tuple_intermediate(): + def 
before(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, relax.const(1, "float32")) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, relax.const(1, "float32"))) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + constant_type = relax.DynTensorType(0, "float32") + + # Grouped function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + p0 = relax.Var("p0", (), constant_type) + p1 = relax.Var("p1", (), constant_type) + p2 = relax.Var("p2", (), constant_type) + p3 = relax.Var("p3", (), constant_type) + p4 = relax.Var("p4", (), constant_type) + with bb.function( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1", + [x, p0, p1, p2, p3, p4], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.squeeze, x) + lv1 = bb.emit_te(topi.add, lv0, p0) + lv2 = bb.emit_te(topi.squeeze, lv0) + lv3 = bb.emit_te(topi.add, lv2, p1) + lv4 = bb.emit_te(topi.add, lv3, p2) + lv5 = bb.emit_te(topi.add, lv0, p3) + lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1) + lv7 = bb.emit_te(topi.squeeze, lv6) + gv = bb.emit_output(bb.call_te(topi.add, lv7, p4)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var( + "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1" + ) + + # Main func + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output( + relax.Call( + fused_func, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_tuple_consecutive(): + def before(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv1 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv2 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32")) + lv5 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv6 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv7 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, relax.const(1, "float32")) + lv10 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv11 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv12 = bb.emit_te(topi.add, x, relax.const(1, "float32")) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, relax.const(1, "float32")) + lv15 = bb.emit_te(topi.concatenate, (lv4, lv9, lv14), axis=1) + lv16 = 
bb.emit_te( + topi.nn.pool2d, + lv15, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + lv17 = bb.emit_te(topi.add, lv16, relax.const(1, "float32")) + lv18 = bb.emit_te(topi.add, lv17, relax.const(1, "float32")) + gv = bb.emit_output((lv17, lv18)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + constant_type = relax.DynTensorType(0, "float32") + + # Grouped function 1 + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + p0 = relax.Var("p0", (), constant_type) + p1 = relax.Var("p1", (), constant_type) + p2 = relax.Var("p2", (), constant_type) + p3 = relax.Var("p3", (), constant_type) + p4 = relax.Var("p4", (), constant_type) + p5 = relax.Var("p5", (), constant_type) + p6 = relax.Var("p6", (), constant_type) + p7 = relax.Var("p7", (), constant_type) + p8 = relax.Var("p8", (), constant_type) + p9 = relax.Var("p9", (), constant_type) + p10 = relax.Var("p10", (), constant_type) + p11 = relax.Var("p11", (), constant_type) + with bb.function( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1", + [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11], + attrs={"Primitive": 1}, + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.add, x, p1) + lv2 = bb.emit_te(topi.add, x, p2) + lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1) + lv4 = bb.emit_te(topi.add, lv3, p3) + lv5 = bb.emit_te(topi.add, x, p4) + lv6 = bb.emit_te(topi.add, x, p5) + lv7 = bb.emit_te(topi.add, x, p6) + lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1) + lv9 = bb.emit_te(topi.add, lv8, p7) + lv10 = bb.emit_te(topi.add, x, p8) + lv11 = bb.emit_te(topi.add, x, p9) + lv12 = bb.emit_te(topi.add, x, p10) + lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1) + lv14 = bb.emit_te(topi.add, lv13, p11) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv4, lv9, lv14), axis=1)) + bb.emit_func_output(gv) + + # Grouped function 2 + concat = relax.Var("concat", (1, 144, 64, 64), tensor_type) + p0 = relax.Var("p0", (), constant_type) + with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.pool2d, + concat, + kernel=(2, 2), + stride=(2, 2), + dilation=(1, 1), + padding=(0, 0, 0, 0), + pool_type="max", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, p0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_func1 = mod.get_global_var( + "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1" + ) + fused_func2 = mod.get_global_var("fused_pool2d_add2") + + # Main function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit( + relax.Call( + fused_func1, + ( + x, + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + relax.const(1, "float32"), + ), + ) + ) + lv1 = bb.emit(relax.Call(fused_func2, (lv0, relax.const(1, "float32")))) + lv2 = bb.emit_te(topi.add, lv1, relax.const(1, "float32")) + gv = bb.emit_output((lv1, lv2)) + bb.emit_func_output(gv) + + return bb.get() 
+ + _check(before(), expected()) + + +def test_inception_like(): + def before(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w0 = relax.Var("w0", (16, 16, 3, 3), tensor_type) + w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type) + w2 = relax.Var("w2", (16, 32, 3, 3), tensor_type) + w3 = relax.Var("w3", (16, 32, 3, 3), tensor_type) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.conv2d, x, w0, strides=1, padding=1, dilation=1) + lv1 = bb.emit_te(topi.nn.relu, lv0) + lv2 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1) + lv3 = bb.emit_te(topi.nn.relu, lv2) + lv4 = bb.emit_te(topi.concatenate, (lv1, lv3), axis=1) + lv5 = bb.emit_te(topi.nn.conv2d, lv4, w2, strides=1, padding=1, dilation=1) + lv6 = bb.emit_te(topi.nn.relu, lv5) + lv7 = bb.emit_te(topi.nn.conv2d, lv4, w3, strides=1, padding=1, dilation=1) + lv8 = bb.emit_te(topi.nn.relu, lv7) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv6, lv8), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, "float32") + + # Grouped function 1 + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w = relax.Var("w", (16, 16, 3, 3), tensor_type) + with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", (1, 32, 64, 64), tensor_type) + w = relax.Var("w", (16, 32, 3, 3), tensor_type) + with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_relu1 = mod.get_global_var("fused_conv2d_relu") + fused_conv2d_relu2 = mod.get_global_var("fused_conv2d1_relu") + + # Main function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w0 = relax.Var("w0", (16, 16, 3, 3), tensor_type) + w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type) + w2 = relax.Var("w2", (16, 32, 3, 3), tensor_type) + w3 = relax.Var("w3", (16, 32, 3, 3), tensor_type) + with bb.function("main", [x, w0, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w0))) + lv1 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w1))) + lv2 = bb.emit_te(topi.concatenate, (lv0, lv1), axis=1) + lv3 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w2))) + lv4 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w3))) + gv = bb.emit_output(bb.call_te(topi.concatenate, (lv3, lv4), axis=1)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_fuse_parallel_injective(): + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", (10, 20), relax.DynTensorType(2, "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, "int32")) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0]) + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + 
bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + + # Grouped function + x = relax.Var("x", (10, 20), relax.DynTensorType(2, "int32")) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "int32")) + with bb.function( + "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1} + ): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.squeeze, lv0) + lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0]) + lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0], primfunc_name_hint="transpose1") + gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3)) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_add_squeeze_transpose_transpose1_left_shift") + + # Main function + x = relax.Var("x", (10, 20), relax.DynTensorType(2, "int32")) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x, relax.const(1, "int32")))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +def test_softmax(): + """Test if softmax can be fused with following ops.""" + + def before(): + bb = relax.BlockBuilder() + + x = relax.Var("x", (16, 16), relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + constant_type = relax.DynTensorType(2, "float32") + + # Grouped function + x = relax.Var("x", (16, 16), constant_type) + with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.nn.softmax, x) + gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + fused_func = bb.get().get_global_var("fused_softmax_cast") + + # Main function + x = relax.Var("x", (16, 16), constant_type) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_func, (x,))) + bb.emit_func_output(gv) + + return bb.get() + + _check(before(), expected()) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py new file mode 100644 index 0000000000..40c9ce98e8 --- /dev/null +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -0,0 +1,574 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
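+
+# The tests below exercise relax.transform.FuseTIR: each grouped relax function
+# carrying the "Primitive" attribute (as produced by FuseOps) is collapsed into
+# a single fused TIR PrimFunc, and its call sites are rewritten to invoke that
+# PrimFunc directly via call_tir.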
+ +import pytest +import sys +import tvm +from tvm import topi +from tvm import relax + + +def _check(mod_before, mod_expected): + mod = relax.transform.FuseTIR()(mod_before) + tvm.ir.assert_structural_equal(mod, mod_expected) + + +def test_simple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32")) + + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, p0) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") + + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_conv2d_fuse(): + def before(dtype): + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, dtype) + + # Grouped function 1 + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w = relax.Var("w", (16, 16, 3, 3), tensor_type) + p0 = relax.Var("p0", (), relax.DynTensorType(0, dtype)) + with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=1, + dilation=1, + primfunc_name_hint="conv2d", + ) + lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") + gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Grouped function 2 + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w = relax.Var("w", (16, 16, 1, 1), tensor_type) + y = relax.Var("y", (1, 16, 64, 64), tensor_type) + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te( + topi.nn.conv2d, + x, + w, + strides=1, + padding=0, + dilation=1, + primfunc_name_hint="conv2d1", + ) + gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) + bb.emit_func_output(gv) + + # Get the global variables of the grouped functions + mod = bb.get() + fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") + fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") + + # Main function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type) + w2 = relax.Var("w2", (16, 16, 1, 1), tensor_type) + w3 = relax.Var("w3", (16, 16, 3, 3), tensor_type) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2])) + bb.emit_func_output(gv) + + 
return bb.get() + + def expected(dtype): + def fused_conv2d_add1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1) + add = topi.add(p, conv) + return topi.add(conv, add) + + def fused_conv2d1_add2(x, w, p): + conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1) + return topi.add(conv, p) + + bb = relax.BlockBuilder() + tensor_type = relax.DynTensorType(4, dtype) + + # Main function + x = relax.Var("x", (1, 16, 64, 64), tensor_type) + w1 = relax.Var("w1", (16, 16, 3, 3), tensor_type) + w2 = relax.Var("w2", (16, 16, 1, 1), tensor_type) + w3 = relax.Var("w3", (16, 16, 3, 3), tensor_type) + with bb.function("main", [x, w1, w2, w3]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) + lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype)) + lv2 = bb.emit_te( + topi.nn.conv2d, + lv1, + w3, + strides=1, + padding=1, + dilation=1, + ) + gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2)) + bb.emit_func_output(gv) + + return bb.get() + + _check(before("float32"), expected("float32")) + + +def test_two_subfunction(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + lv2 = bb.emit(relax.Call(func_gv, [lv])) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(fused_exp_squeeze, lv) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_same_primfunc(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv1 = bb.emit_te(topi.exp, x1) + lv2 = bb.emit_te(topi.exp, lv1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_exp_squeeze") + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_exp_squeeze(x): + exp = topi.exp(x) + exp = topi.exp(exp) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_exp_squeeze, x) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_tuple_as_param(): + dyn_tensor_type = relax.DynTensorType(1, "float32") + tuple_type = relax.TupleType([dyn_tensor_type, dyn_tensor_type]) + 
tuple_shape = relax.Tuple([relax.ShapeExpr([10]), relax.ShapeExpr([10])]) + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_shape, tuple_type) + with bb.function("fused_exp_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.add, lv2, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add") + x = relax.Var("x", tuple_shape, tuple_type) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add(x1, x2): + exp = topi.exp(x1) + return topi.add(exp, x2) + + bb = relax.BlockBuilder() + dyn_tensor_type = relax.DynTensorType(1, "float32") + tuple_type = relax.TupleType([dyn_tensor_type, dyn_tensor_type]) + tuple_shape = relax.Tuple([relax.ShapeExpr([10]), relax.ShapeExpr([10])]) + x = relax.Var("x", tuple_shape, tuple_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add, lv0, lv1)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_nested_tuple_as_param(): + dyn_tensor_type = relax.DynTensorType(1, "float32") + tuple_type = relax.TupleType( + [dyn_tensor_type, relax.TupleType([dyn_tensor_type, dyn_tensor_type])] + ) + shape = relax.ShapeExpr([10]) + tuple_shape = relax.Tuple([shape, relax.Tuple([shape, shape])]) + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_shape, tuple_type) + with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv0_exp = bb.emit_te(topi.exp, lv0) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1)) + lv2 = bb.emit_te(topi.add, lv1_0, lv1_1) + gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_add_add") + x = relax.Var("x", tuple_shape, tuple_type) + with bb.function("main", [x]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_add_add(x1, x2, x3): + exp = topi.exp(x1) + add = topi.add(x2, x3) + return topi.add(exp, add) + + bb = relax.BlockBuilder() + x = relax.Var("x", tuple_shape, tuple_type) + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.TupleGetItem(x, 0)) + lv1 = bb.emit(relax.TupleGetItem(x, 1)) + lv2 = bb.emit(relax.TupleGetItem(lv1, 0)) + lv3 = bb.emit(relax.TupleGetItem(lv1, 1)) + gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_call_tir_in_main(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(topi.exp, x1) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_exp_squeeze") + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with 
bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit(relax.Call(func_gv, [x])) + lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32")) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_exp_squeeze(x): + exp = topi.exp(x) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_exp_squeeze, x) + lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32")) + gv = bb.emit_output(lv2) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_const_in_argument(): + def before(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", [10, 20], relax.DynTensorType(2, "float32")) + x2 = relax.Var("x2", [], relax.DynTensorType(0, "float32")) + with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x1, x2) + lv1 = bb.emit_te(topi.exp, lv0) + gv = bb.emit_output(bb.call_te(topi.squeeze, lv1)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_add_exp_squeeze") + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")])) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_add_exp_squeeze(x, y): + add = topi.add(x, y) + exp = topi.exp(add) + squeeze = topi.squeeze(exp) + return squeeze + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x]): + with bb.dataflow(): + lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32")) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_tuple_output(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32")) + + with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}): + with bb.dataflow(): + gv0 = bb.emit_output(bb.call_te(topi.add, x, p0)) + gv1 = bb.emit_output(bb.call_te(topi.exp, gv0)) + bb.emit_func_output(relax.Tuple([gv0, gv1])) + fused_add_exp = bb.get().get_global_var("fused_add_exp") + + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + def fused_add_exp(x, p0): + add = topi.add(x, p0) + exp = topi.exp(add) + return add, exp + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + p0 = relax.Var("p0", (), relax.DynTensorType(0, "float32")) + with bb.function("main", [x, p0]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_with_immediate_tuple(): + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + y = relax.Var("y", [10, 20], relax.DynTensorType(2, "float32")) + + with bb.function("fused_add", [x, y], attrs={"Primitive": True}): + with bb.dataflow(): + lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])])) + lv_x 
= bb.emit(relax.TupleGetItem(lv_tuple, 0)) + lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1)) + lv_y = bb.emit(relax.TupleGetItem(lv0, 1)) + gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y)) + bb.emit_func_output(gv) + fused_add = bb.get().get_global_var("fused_add") + + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + y = relax.Var("y", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(fused_add, [x, y])) + bb.emit_func_output(gv) + + return bb.get() + + def expected(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + y = relax.Var("y", [10, 20], relax.DynTensorType(2, "float32")) + with bb.function("main", [x, y]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add")) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +def test_fuse_return_partial_result(): + def te_argmax_idx_val(val): + from tvm import te + + def f_combine(x, y): + lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): + return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + + argmax = te.comm_reducer(f_combine, f_identity, name="argmax") + m, n = val.shape + k = te.reduce_axis((0, n), "k") + max_idx, max_val = te.compute( + (m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax" + ) + return max_idx, max_val + + def before(): + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + offset = relax.Var("offset", [10], relax.DynTensorType(1, "int32")) + with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}): + with bb.dataflow(): + lv = bb.emit_te(te_argmax_idx_val, x) + idx = bb.emit(relax.TupleGetItem(lv, 0)) + gv = bb.emit_output(bb.call_te(topi.add, idx, offset)) + bb.emit_func_output(gv) + mod = bb.get() + + func_gv = mod.get_global_var("fused_argmax_add") + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + offset = relax.Var("x", [10], relax.DynTensorType(1, "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(relax.Call(func_gv, [x, offset])) + bb.emit_func_output(gv) + return bb.get() + + def expected(): + def fused_argmax_add(x, offset): + idx, value = te_argmax_idx_val(x) + idx = topi.add(idx, offset) + return idx + + bb = relax.BlockBuilder() + x = relax.Var("x", [10, 20], relax.DynTensorType(2, "float32")) + offset = relax.Var("offset", [10], relax.DynTensorType(1, "int32")) + with bb.function("main", [x, offset]): + with bb.dataflow(): + gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset)) + bb.emit_func_output(gv) + return bb.get() + + _check(before(), expected()) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py new file mode 100644 index 0000000000..996a25eea1 --- /dev/null +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations +import pytest +import tvm +from tvm import relax +from tvm.runtime.object import Object +import tvm.script +from tvm.script import relax as R, tir as T +from tvm.relax import transform +from tvm.ir.base import assert_structural_equal + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x) + yhash = tvm.ir.structural_hash(y) + + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +def test_basic(): + # the target IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0(x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32")) -> Tensor: + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + @R.function + def main( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + inner = lifted_func_0 + gv1 = inner(x1, y1) + return gv1 + + @tvm.script.ir_module + class Before: + @R.function + def main( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + @R.function + def inner( + x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + gv1: Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_closure(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")): + outer_func = lifted_func_0 + in_call = outer_func(x) + res = relax.invoke_closure(in_call, (y,), type_args=(Tensor(ndim=2, dtype="float32"))) + return res + + @R.function + def lifted_func_1(x1: Tensor((2, 3), "float32"), c1: Tensor((2, 3), "float32")): + r_1: Tensor((2, 3), "float32") = relax.add(x1, c1) + return r_1 + + @R.function + def lifted_func_0(y: Tensor((2, 3), "float32")): + return relax.make_closure(lifted_func_1, (y,)) + + # IRModule to perform Lambda Lifting + @tvm.script.ir_module + class Before: + @R.function + def main( + x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32") + ) -> Tensor((2, 3), "float32"): + @R.function + def outer_func(c1: Tensor((2, 3), "float32")): + @R.function + def inner_func(x1: Tensor((2, 3), "float32")): + s: Tensor((2, 3), "float32") = relax.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + before = Before + after = transform.LambdaLift()(before) + expected = Expected + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def 
test_recursive():
+    # the expected IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def lifted_func_0(
+            i: Tensor((), "int32"), s: Tensor((2, 3), "float32"), x: Tensor((2, 3), "float32")
+        ) -> Tensor((2, 3), "float32"):
+            cond: Tensor((), "bool") = relax.call_packed(
+                "test.vm.less", i, relax.const(10), type_args=(Tensor(ndim=0, dtype="bool"))
+            )
+            c: Tensor((), "int32") = relax.const(1, dtype="int32")
+            if cond:
+                new_i: Tensor((), "int32") = relax.add(i, c)
+                new_s: Tensor((2, 3), "float32") = relax.add(s, x)
+                r = lifted_func_0(new_i, new_s, x)
+            else:
+                r = s
+            return r
+
+        @R.function
+        def main(x: Tensor((2, 3), "float32")) -> Tensor:
+            while_loop = relax.make_closure(lifted_func_0, (x,))
+            gv = relax.invoke_closure(
+                while_loop, (relax.const(0), x), type_args=(Tensor(ndim=2, dtype="float32"))
+            )
+            return gv
+
+    # the IRModule to apply lambda lifting
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: Tensor((2, 3), "float32")) -> Tensor:
+            @R.function
+            def while_loop(
+                i: Tensor((), "int32"), s: Tensor((2, 3), "float32")
+            ) -> Tensor((2, 3), "float32"):
+                cond: Tensor((), "bool") = relax.call_packed(
+                    "test.vm.less", i, relax.const(10), type_args=(Tensor(ndim=0, dtype="bool"))
+                )
+                c: Tensor((), "int32") = relax.const(1, dtype="int32")
+                if cond:
+                    new_i: Tensor((), "int32") = relax.add(i, c)
+                    new_s: Tensor((2, 3), "float32") = relax.add(s, x)
+                    r: Tensor((2, 3), "float32") = while_loop(new_i, new_s)
+                else:
+                    r: Tensor((2, 3), "float32") = s
+                return r
+
+            gv: Tensor((2, 3), "float32") = while_loop(relax.const(0), x)
+            return gv
+
+    before = Before
+    expected = Expected
+    # Perform Lambda Lifting
+    after = transform.LambdaLift()(before)
+    assert len(after.functions) == 2
+    assert_structural_equal(after, expected, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+def test_multi_func():
+    # expected IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def glob_func_1(
+            x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32")
+        ) -> Tensor(None, "float32", ndim=2):
+            inner = lifted_func_1
+            gv1 = inner(x1, y1)
+            return gv1
+
+        @R.function
+        def glob_func_2(
+            x11: Tensor((10, 5), "float32"), y11: Tensor((10, 5), "float32")
+        ) -> Tensor(None, "float32", ndim=2):
+            inner1 = lifted_func_0
+            gv11 = inner1(x11, y11)
+            return gv11
+
+        @R.function
+        def lifted_func_0(
+            x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32")
+        ) -> Tensor(None, "float32", ndim=2):
+            s: Tensor((10, 5), "float32") = relax.add(x2, y2)
+            return s
+
+        @R.function
+        def lifted_func_1(
+            x21: Tensor((10, 5), "float32"), y21: Tensor((10, 5), "float32")
+        ) -> Tensor(None, "float32", ndim=2):
+            s1: Tensor((10, 5), "float32") = relax.add(x21, y21)
+            return s1
+
+    # the IRModule to apply lambda lifting
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def glob_func_1(
+            x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32")
+        ) -> Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32")
+            ) -> Tensor((10, 5), "float32"):
+                s: Tensor((10, 5), "float32") = relax.add(x2, y2)
+                return s
+
+            gv1: Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+        @R.function
+        def glob_func_2(
+            x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32")
+        ) -> Tensor((10, 5), "float32"):
+            @R.function
+            def inner(
+                x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32")
+            ) -> Tensor((10, 5), "float32"):
+                s: Tensor((10, 5), "float32") = relax.add(x2, y2)
+                return s
+
+            gv1: Tensor((10, 5), "float32") = inner(x1, y1)
+            return gv1
+
+    before = Before
+    expected = Expected
+    # Perform Lambda Lifting
+    after = transform.LambdaLift()(before)
+    assert len(after.functions) == 4
+    assert_structural_equal(after, expected, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+def test_no_local_func():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def sub(
+            A: T.Buffer[(16, 16), "float32"],
+            B: T.Buffer[(16, 16), "float32"],
+            C: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("sub"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = A[vi, vj] - B[vi, vj]
+
+        @R.function
+        def before(c0: Tensor((16, 16), "float32"), x: Tensor((_, _), "float32")):
+            s = relax.call_tir(sub, (c0, x), (16, 16), dtype="float32")
+            return s
+
+    before = Before
+    # Perform lambda lifting
+    after = transform.LambdaLift()(before)
+    # No local functions are lifted
+    assert_structural_equal(after, before, map_free_vars=True)
+    _check_save_roundtrip(after)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_lower_with_op_strategy.py b/tests/python/relax/test_transform_lower_with_op_strategy.py
new file mode 100644
index 0000000000..148df11014
--- /dev/null
+++ b/tests/python/relax/test_transform_lower_with_op_strategy.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
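+# A rough sketch of the flow exercised in this file (all names are defined
+# below): lower a Relax module with the Relay op strategy, then tune, compile,
+# and run it through MetaSchedule:
+#
+#   out_mod = transform.LowerWithRelayOpStrategyPass(target)(mod)
+#   db = ms.relax_integration.tune_relax(mod=out_mod, target=target, ...)
+#   ex = ms.relax_integration.compile_relax(db, out_mod, target, params=None)
+#   relax.VirtualMachine(ex, dev)["main"](*inputs)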
+ +from __future__ import annotations + +import tempfile + +import numpy as np +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import meta_schedule as ms +from tvm import relax +from tvm.relax.testing import transform +from tvm.script import relax as R +from tvm.target import Target + + +@tvm.script.ir_module +class InputModule: + @R.function + def main( + x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + gv0 = R.multiply(x, w) + gv1 = R.add(x, gv0) + return gv1 + + +def build_and_run(mod, target, dev, np_inputs): + inputs = [tvm.nd.array(np_input, dev) for np_input in np_inputs] + with tempfile.TemporaryDirectory() as work_dir: + db = ms.relax_integration.tune_relax( + mod=mod, + params=None, + target=target, + work_dir=work_dir, + num_trials_per_iter=20, + max_trials_global=20, + task_scheduler="round-robin", + ) + ex = ms.relax_integration.compile_relax(db, mod, target, params=None) + vm = relax.VirtualMachine(ex, dev) + vm["main"](*inputs) + + +def _test_lowering(target, dev): + mod = InputModule + assert mod + with tvm.transform.PassContext(opt_level=3): + out_mod = transform.LowerWithRelayOpStrategyPass(target)(mod) + + input_shape = (16, 16) + np_inputs = [ + np.random.rand(*input_shape).astype(np.float32), + np.random.rand(*input_shape).astype(np.float32), + ] + build_and_run(out_mod, target, dev, np_inputs) + + +def test_lowering_cpu(target_str="llvm --num-cores=16"): + _test_lowering(Target(target_str), tvm.cpu()) + + +@tvm.testing.requires_gpu +def test_lowering_gpu(target_str="nvidia/nvidia-t4"): + _test_lowering(Target(target_str), tvm.cuda()) + + +if __name__ == "__main__": + test_lowering_cpu() + test_lowering_gpu() diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py new file mode 100644 index 0000000000..739cd42847 --- /dev/null +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
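+# A rough sketch of the flow exercised in this file (all names are defined
+# below): run a MetaSchedule tuning pass under a traced PassContext, which
+# only populates the tuning database, then apply the tuned schedules:
+#
+#   with target, transform.PassContext(trace=Trace(mod), opt_level=0):
+#       out_mod = relax.transform.MetaScheduleTuneIRMod(
+#           params={}, work_dir=work_dir, max_trials_global=4
+#       )(mod)                                       # structurally equal to mod
+#       out_mod = relax.transform.MetaScheduleApplyDatabase(work_dir)(mod)
+#                                                    # rewritten with tuned schedules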
+
+from __future__ import annotations
+
+import tempfile
+
+import pytest
+import tvm
+import tvm.meta_schedule as ms
+from tvm import relax
+from tvm.ir import transform
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.relax.transform.tuning_api import Trace
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+target = tvm.target.Target("llvm --num-cores=16")
+
+
+@tvm.script.ir_module
+class InputModule:
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        T.func_attr({"global_symbol": "tir_matmul"})
+        k = T.var("int32")
+        A = T.match_buffer(x, (32, 32))
+        B = T.match_buffer(y, (32, 32))
+        C = T.match_buffer(z, (32, 32))
+
+        for (i0, j0, k0) in T.grid(32, 32, 32):
+            with T.block():
+                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
+                with T.init():
+                    C[i, j] = 0.0
+                C[i, j] += A[i, k] * B[j, k]
+
+    @T.prim_func
+    def tir_relu(x: T.handle, y: T.handle):
+        T.func_attr({"global_symbol": "tir_relu"})
+        A = T.match_buffer(x, (32, 32))
+        B = T.match_buffer(y, (32, 32))
+        for (i, j) in T.grid(32, 32):
+            with T.block():
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = T.max(A[vi, vj], 0.0)
+
+    @R.function
+    def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor:
+        with R.dataflow():
+            lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32")
+            lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype="float32")
+            relax.output(lv1)
+        return lv1
+
+
+# TODO(@sunggg): determine how to pass MS database object across different passes.
+# PassContext might be an option, but we already have TuningAPI database.
+# (MS database and TuningAPI database will be unified in the future)
+# For now, we only support default JSON database config.
+def test_ms_tuning_irmodule():
+
+    mod = InputModule
+    assert isinstance(mod, IRModule)
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, transform.PassContext(trace=Trace(mod), opt_level=0):
+            tuning_pass = relax.transform.MetaScheduleTuneIRMod(
+                params={}, work_dir=work_dir, max_trials_global=4
+            )
+            out_mod = tuning_pass(mod)
+            assert PassContext.current().get_trace_stack_size() == 1
+            assert PassContext.current().get_current_trace().size == 1
+            tvm.ir.assert_structural_equal(mod, out_mod)
+
+            application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir)
+
+            out_mod = application_pass(mod)
+            assert not tvm.ir.structural_equal(mod, out_mod)
+
+
+def test_ms_tuning_primfunc():
+    mod = InputModule
+    assert isinstance(mod, IRModule)
+    with tempfile.TemporaryDirectory() as work_dir:
+        with target, transform.PassContext(trace=Trace(mod), opt_level=0):
+            tuning_pass = relax.transform.MetaScheduleTuneTIR(
+                work_dir=work_dir, max_trials_global=4
+            )
+            out_mod = tuning_pass(mod)
+            assert PassContext.current().get_trace_stack_size() == 1
+            # TODO (@sunggg): Need to determine how to track subgraph-level tuning traces.
+            # Currently, we don't track this, so the trace size is not checked here.
+            # Revisit this later.
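+            # The tuning pass itself only populates the database; the module is
+            # rewritten later by MetaScheduleApplyDatabase, which is why the
+            # structural equality below still holds.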
+            tvm.ir.assert_structural_equal(mod, out_mod)
+
+            application_pass = relax.transform.MetaScheduleApplyDatabase(work_dir)
+            out_mod = application_pass(mod)
+            assert not tvm.ir.structural_equal(mod, out_mod)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/relax/test_transform_remove_unused_funcs.py b/tests/python/relax/test_transform_remove_unused_funcs.py
new file mode 100644
index 0000000000..673cb83aea
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_unused_funcs.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+# TODO (sunggg):
+# Currently, the parser sets each function name as the global symbol by default,
+# and the printer uses the global symbol to print out the function name.
+# This is temporary and will be improved in the future.
+
+
+def check_if_func_exists(mod, func_name):
+    gvs = [str(gv) for gv in mod.get_global_vars()]
+    return ("@" + func_name) in gvs
+
+
+def test_unused_relax_func():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_add(
+            x: T.Buffer[(16, 16), "float32"],
+            y: T.Buffer[(16, 16), "float32"],
+            z: T.Buffer[(16, 16), "float32"],
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function
+        def unused_func(x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")):
+            gv0 = relax.add(x, w)
+            return gv0
+
+        @R.function
+        def main(
+            x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")
+        ) -> Tensor((16, 16), "float32"):
+            gv0 = R.call_tir(tir_add, (x, w), (16, 16), dtype="float32")
+            return gv0
+
+    mod = InputModule
+    assert mod
+    # The RemoveUnusedFunctions pass won't remove a function that has a global
+    # symbol, since it may be referenced externally.
+    new_mod = relax.transform.RemoveUnusedFunctions()(mod)
+    assert check_if_func_exists(new_mod, "main")
+    assert check_if_func_exists(new_mod, "tir_add")
+    assert check_if_func_exists(new_mod, "unused_func")
+
+    # Remove global symbol from the function.
+    mod["unused_func"] = mod["unused_func"].without_attr("global_symbol")
+
+    # Then, the pass removes the unused function, which no longer has any global symbol.
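+    # (Without a global symbol the function can no longer be looked up from
+    # outside the module, so dropping it cannot break external callers.)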
+ new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_relax_func_custom_entry_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")): + gv0 = relax.add(x, w) + return gv0 + + @R.function + def foo( + x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + gv0 = R.call_tir(tir_add, (x, w), (16, 16), dtype="float32") + return gv0 + + mod = InputModule + assert mod + + # RemoveUnusedFunction pass won't remove the function with global symbol for the external reference. + # Test entry function other than "main". + new_mod = relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod) + assert check_if_func_exists(new_mod, "foo") + assert check_if_func_exists(new_mod, "tir_add") + assert check_if_func_exists(new_mod, "unused_func") + + # Remove global symbol from the function. + mod["unused_func"] = mod["unused_func"].without_attr("global_symbol") + + # Then, this removes the unused function without any global symbol. + new_mod = relax.transform.RemoveUnusedFunctions(entry_functions=["foo"])(mod) + assert check_if_func_exists(new_mod, "foo") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_relax_func_symbolic_shape(): + # Test with relax function w/ symbolic shape. + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_add( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func(x: Tensor((m, n), "float32"), w: Tensor((n, k), "float32")): + gv0 = relax.add(x, w) + return gv0 + + @R.function + def main(x: Tensor((m, n), "float32"), w: Tensor((n, k), "float32")): + gv0 = R.call_tir(tir_add, (x, w), (m, k), dtype="float32") + return gv0 + + mod = InputModule + assert mod + + # RemoveUnusedFunction pass won't remove the function with global symbol for the external reference. + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert check_if_func_exists(new_mod, "unused_func") + + # Remove global symbol from the unused function + mod["unused_func"] = mod["unused_func"].without_attr("global_symbol") + # Remove unused function before shape lowering. + # Test entry function other than "main". + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert not check_if_func_exists(new_mod, "unused_func") + + # Remove unused function after shape lowering. + # Shape lowering will inject several shape-related global functions. + # We need to make sure unused function removal pass does not remove those functions. 
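+    # (For instance, VMShapeLower injects a "shape_func", which the checks
+    # below expect to survive the removal pass.)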
+ shape_lowered_mod = relax.transform.VMShapeLower()(mod) + new_mod = relax.transform.RemoveUnusedFunctions()(shape_lowered_mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "tir_add") + assert check_if_func_exists(new_mod, "shape_func") # injected by VMShapeLower pass + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_unused_prim_func(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def unused_func( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + T.func_attr({"global_symbol": "tir_unused"}) + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def relax_add(x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")): + gv0 = relax.add(x, w) + return gv0 + + @R.function + def main( + x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + gv0 = relax_add(x, w) + return gv0 + + mod = InputModule + assert mod + # RemoveUnusedFunction pass won't remove the function with global symbol for the external reference. + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "relax_add") + assert check_if_func_exists(new_mod, "unused_func") + + # Remove global symbol from the unused function + mod["unused_func"] = mod["unused_func"].without_attr("global_symbol") + + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "relax_add") + assert not check_if_func_exists(new_mod, "unused_func") + + +def test_multiple_unused_funcs(): + @tvm.script.ir_module + class InputModule: + @T.prim_func + def unused_func1( + x: T.Buffer[(16, 16), "float32"], + y: T.Buffer[(16, 16), "float32"], + z: T.Buffer[(16, 16), "float32"], + ) -> None: + T.func_attr({"global_symbol": "tir_unused"}) + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + z[vi, vj] = x[vi, vj] + y[vi, vj] + + @R.function + def unused_func2(x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32")): + gv0 = relax.add(x, w) + return gv0 + + @R.function + def main( + x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + gv0 = relax.add(x, w) + return gv0 + + mod = InputModule + assert mod + + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert check_if_func_exists(new_mod, "unused_func1") + assert check_if_func_exists(new_mod, "unused_func2") + + # Remove global symbol from unused functions + mod["unused_func1"] = mod["unused_func1"].without_attr("global_symbol") + mod["unused_func2"] = mod["unused_func2"].without_attr("global_symbol") + + new_mod = relax.transform.RemoveUnusedFunctions()(mod) + assert check_if_func_exists(new_mod, "main") + assert not check_if_func_exists(new_mod, "unused_func1") + assert not check_if_func_exists(new_mod, "unused_func2") + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_well_formed.py b/tests/python/relax/test_transform_well_formed.py new file mode 100644 index 0000000000..55a31547b6 --- /dev/null +++ b/tests/python/relax/test_transform_well_formed.py @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +from tvm import tir +from tvm import relax as rx + +m = tir.Var("m", "int32") +n = tir.Var("n", "int32") +type_anno = rx.DynTensorType(ndim=2, dtype="float16") +bool_type_anno = rx.DynTensorType(ndim=0, dtype="bool") +x = rx.Var("x", [m, n], type_anno) +cond = rx.Var("cond", [], bool_type_anno) + + +def build_function(blocks, params=[]): + """Returns relax.function with given blocks""" + seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) + ret_type = rx.DynTensorType(ndim=-1, dtype="float32") + ret_shape = rx.RuntimeDepShape() + func = rx.Function([x, cond] + params, seq_expr, ret_type, ret_shape).with_attr( + "global_symbol", "foo" + ) + return func + + +def test_var(): + # Error: Var gv0 is not defined + gv0 = rx.Var("gv0", [m, n], type_anno) + gv1 = rx.Var("gv1", [m, n], type_anno) + call_node = rx.op.add(x, gv0) + bindings = [rx.VarBinding(gv1, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + # Error: Var gv0 is defined more than once + gv0 = rx.Var("gv0", [m, n], type_anno) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(gv0, call_node), rx.VarBinding(gv0, call_node2)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_dataflow_var(): + # Error: DataflowVar lv0 is not defined + lv0 = rx.DataflowVar("lv0", [m, n], type_anno) + gv0 = rx.Var("gv0", [m, n], type_anno) + call_node = rx.op.add(x, lv0) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + # Error: DataflowVar gv0 is defined more than once + lv0 = rx.DataflowVar("lv0", [m, n], type_anno) + call_node = rx.op.add(x, x) + call_node2 = rx.op.multiply(x, x) + bindings = [rx.VarBinding(lv0, call_node), rx.VarBinding(lv0, call_node2)] + blocks = [rx.DataflowBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + # Error: DataflowVar lv0 is defined outside DataflowBlock + lv0 = rx.DataflowVar("lv0", [m, n], type_anno) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(lv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + # Error: DataflowVar lv0 is used outside DataflowBlock + lv0 = rx.DataflowVar("lv0", [m, n], type_anno) + gv0 = rx.Var("gv0", [m, n], type_anno) + call_node = rx.op.add(lv0, x) + bindings = [rx.VarBinding(lv0, 
call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_global_var(): + # Error: GlobalVar GlobalVar0 is not defined + gv0 = rx.Var("gv0", [m, n], type_anno) + globalvar = rx.GlobalVar("GlobalVar0") + call_node = rx.Call( + op=tvm.ir.Op.get("relax.call_tir"), + args=[globalvar, rx.Tuple([x]), rx.ShapeExpr([m, n])], + ) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_symbolic_var(): + # Error: Symbolic Var new_s is not defined + new_s = tir.Var("new_s", "int32") + gv0 = rx.Var("gv0", [m, new_s], type_anno) + call_node = rx.op.add(x, x) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_symbolic_var_invalid_type(): + # Error: dim must be of integer type, but got float32 + dim = tir.Var("dim", "float32") + type_anno = rx.DynTensorType(ndim=1, dtype="float32") + y = rx.Var("y", [dim], type_anno) + gv0 = rx.Var("gv0", [dim], type_anno) + call_node = rx.op.add(y, y) + bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks, [y]) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_seq_expr(): + # Error: SeqExpr in VarBinding + gv0 = rx.Var("gv0", [m, n], type_anno) + # build a SeqExpr + gv1 = rx.Var("gv1", [m, n], type_anno) + call_node = rx.op.add(x, gv0) + _bindings = [rx.VarBinding(gv1, call_node)] + _blocks = [rx.BindingBlock(_bindings)] + _seq_expr = rx.SeqExpr(_blocks, gv1) + # build a Binding with the SeqExpr as value + bindings = [rx.VarBinding(gv0, _seq_expr)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_if(): + # Error: Var defined in true/false branch is invisible in the outer scope + # except the return Var, i.e the var in the last stmt + # v_in_if is invisible in the outer scope + v_in_if = rx.Var("v_in_if", [m, n], type_anno) + # gv0 is visible in the outer scope + gv0 = rx.Var("gv0", [m, n], type_anno) + # build true branch + true_bindings = [ + rx.VarBinding(v_in_if, rx.op.add(x, x)), + rx.VarBinding(gv0, rx.op.multiply(x, x)), + ] + true_blocks = [rx.BindingBlock(true_bindings)] + true_seq_expr = rx.SeqExpr(true_blocks, true_blocks[-1].bindings[-1].var) + # build false branch + false_bindings = [ + rx.VarBinding(v_in_if, rx.op.multiply(x, x)), + rx.VarBinding(gv0, rx.op.add(x, x)), + ] + false_blocks = [rx.BindingBlock(false_bindings)] + false_seq_expr = rx.SeqExpr(false_blocks, false_blocks[-1].bindings[-1].var) + # build If node + if_node = rx.If(cond=cond, true_branch=true_seq_expr, false_branch=false_seq_expr) + gv1 = rx.Var("gv1", [m, n], type_anno) + # try to call v_in_if defined in the true/false branch + bindings = [rx.VarBinding(gv0, if_node), rx.VarBinding(gv1, v_in_if)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +def test_ANF(): + # Error: Nested Call + gv0 = rx.Var("gv0", [m, n], type_anno) + call_node = rx.op.add(x, rx.op.add(x, x)) + 
bindings = [rx.VarBinding(gv0, call_node)] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + # Error: Call Node in Tuple + gv0 = rx.Var("gv0", [m, n], type_anno) + bindings = [rx.VarBinding(gv0, rx.Tuple((x, rx.op.add(x, x))))] + blocks = [rx.BindingBlock(bindings)] + func = build_function(blocks) + mod = tvm.IRModule({rx.GlobalVar("foo"): func}) + assert not rx.analysis.well_formed(mod) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py new file mode 100644 index 0000000000..c8734370c4 --- /dev/null +++ b/tests/python/relax/test_tuning_api.py @@ -0,0 +1,784 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import pytest +import numpy as np +import os.path as osp +import tempfile +from typing import List +from math import isclose + +import tvm +from tvm import ir +from tvm.ir import transform +from tvm.ir.transform import PassContext +from tvm.ir.module import IRModule +from tvm.script import tir as T, relax as R +from tvm import relax +from tvm.relax.expr import Expr, DataflowBlock, Function +from tvm.relax.transform.tuning_api import ( + Choice, + Knob, + Trace, + TuningRecord, + JSONDatabase, + default_generate_candidate, + default_consider_eval_passes, + default_evaluate, + select_best_candidate, + get_trace, +) + + +@tvm.script.ir_module +class TestModule: + @T.prim_func + def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16), "int32"]) -> None: + T.func_attr(({"global_symbol": "addone"})) + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + # Input IRModule. + @R.function + def before(c0: Tensor((16, 16), "int32")): + lv0 = relax.call_tir(addone, (c0,), (16, 16), dtype="int32") + return lv0 + + # Expected IRModule after transformation. + @R.function + def expected(c1: Tensor((16, 16), "int32")): + lv0 = c1 + return c1 + + +def gen_mod(mod, name, binding): + funcs = {} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + + for k, v in mod.functions.items(): + if isinstance(v, tvm.relax.Function): + if k.name_hint == name: + # rename to main. + gv = tvm.ir.GlobalVar("main") + funcs[gv] = tvm.relax.Function(v.params, v.body, v.ret_type, v.ret_shape).with_attr( + "global_symbol", "main" + ) + else: + funcs[k] = v + mod = tvm.IRModule(funcs) + return relax.transform.BindParams("main", binding)(mod) + + +# Setup for simple testing with IRModule. +def setup_test(): + mod = TestModule + assert isinstance(mod, tvm.IRModule) + return gen_mod(mod, "before", {}) + + +# Setup for testing with constant folding. 
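+# As a concrete example of the folding exercised below: "before" binds
+# c0 = np.arange(16 * 16).reshape(16, 16) and computes addone(c0), while
+# "expected" binds c1 = c0 + 1 directly, so once the call_tir to addone is
+# constant-folded the two modules become structurally equal.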
+def setup_test_const_folding():
+    mod = TestModule
+    assert isinstance(mod, tvm.IRModule)
+    # Test setup.
+    c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+    c1_np = c0_np + 1
+    before = gen_mod(mod, "before", {"c0": c0_np})
+    expected = gen_mod(mod, "expected", {"c1": c1_np})
+
+    return before, expected
+
+
+# Define a choice by using the FoldConstant pass.
+@tvm.register_func("testing.apply_fold_constant")
+def apply_fold_constant(mod):
+    return relax.transform.FoldConstant()(mod)
+
+
+@tvm.register_func("testing.add_global_symbol")
+def add_global_symbol(mod, func_name, global_symbol):
+    mod[func_name] = mod[func_name].with_attr("global_symbol", global_symbol)
+    return mod
+
+
+@tvm.register_func("testing.check_num_functions")
+def check_num_funcs(mod, N):
+    # Explicit type specification is necessary.
+    # Otherwise, PackedFunc cannot derive the return type correctly.
+    # e.g., Check failed: type_code_ == kDLInt (8 vs. 0) : expected int but got Object
+    return bool(len(mod.functions) == N)


def test_choice():
+    # Test setup.
+    (
+        before,
+        expected,
+    ) = setup_test_const_folding()
+
+    # Without any argument, the default settings are used for both the
+    # transformation and constraint functions. The default transformation
+    # function returns the original IRModule without any change.
+    choice = Choice(
+        # - transform_func_key="relax.tuning_api.Choice.default_transform_func"
+        # - constr_func_key="relax.tuning_api.Choice.default_constr_func")
+    )
+    # Load transformation function from the choice and apply it.
+    after = choice.apply_transform_func(before)
+    tvm.ir.assert_structural_equal(after, before)
+
+    choice = Choice("testing.apply_fold_constant")
+    # Load transformation function from the choice and apply it.
+    after = choice.apply_transform_func(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+    # Create a choice that tags a global symbol onto the target function.
+    choice = Choice("testing.add_global_symbol", ["addone", "test-symbol"])
+    after = choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "test-symbol"
+    # The transformation should be applied with Copy-On-Write.
+    # So, the original module should be unchanged.
+    assert before["addone"].attrs["global_symbol"] == "addone"
+
+    # Test choice with an impossible constraint
+    choice = Choice(
+        transform_func_key="testing.add_global_symbol",
+        transform_func_args=["addone", "test-symbol"],
+        constr_func_key="testing.check_num_functions",
+        constr_func_args=[1000],
+    )
+    # Since the constraint is not met, it should return the original IRModule
+    after = choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "addone"
+
+    # Test choice with a satisfiable constraint
+    choice = Choice(
+        transform_func_key="testing.add_global_symbol",
+        transform_func_args=["addone", "test-symbol"],
+        constr_func_key="testing.check_num_functions",
+        constr_func_args=[2],
+    )
+    # Since the constraint is met, the transformation should be applied
+    after = choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "test-symbol"
+    # The original module should be unchanged.
+    assert before["addone"].attrs["global_symbol"] == "addone"
+
+    # Test roundtrip.
+    # Export as JSON.
+    json_obj = choice.as_json()
+    # Import JSON.
+    new_choice = Choice.from_json(json_obj)
+    # Test imported choice
+    after = new_choice.apply_transform_func(before)
+    assert after["addone"].attrs["global_symbol"] == "test-symbol"
+    # The original module should be unchanged.
+    assert before["addone"].attrs["global_symbol"] == "addone"
+
+
+def test_knob():
+    # Test setup.
+    before, expected = setup_test_const_folding()
+
+    # Users can define a set of choices with a list.
+    choices = [
+        Choice("testing.apply_fold_constant"),
+        Choice(),
+    ]
+
+    # Define knob.
+    knob = Knob("TestKnob", choices)
+    # Check the sanity of the decision space.
+    assert knob.verify(0)
+    assert knob.verify(1)
+    assert not knob.verify(3)
+
+    # Check the sanity of each decision.
+    after_apply = knob.apply(before, 0)
+    after_noapply = knob.apply(before, 1)
+
+    tvm.ir.assert_structural_equal(after_apply, expected)
+    tvm.ir.assert_structural_equal(after_noapply, before)
+
+    # Users can define a set of choices with a dict.
+    choices = {
+        "apply": Choice("testing.apply_fold_constant"),
+        "noapply": Choice(),
+        "apply_with_impossible_constr": Choice(
+            transform_func_key="testing.apply_fold_constant",
+            constr_func_key="testing.check_num_functions",
+            constr_func_args=[1000],
+        ),
+    }
+    # Define knob.
+    knob = Knob("TestKnob", choices)
+    assert knob.verify("apply")
+    assert knob.verify("noapply")
+    assert knob.verify("apply_with_impossible_constr")
+    assert not knob.verify("INVALID")
+
+    after_apply = knob.apply(before, "apply")
+    after_noapply = knob.apply(before, "noapply")
+    # Because the constraint was not satisfied, it will return the original IRModule
+    after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr")
+    tvm.ir.assert_structural_equal(after_apply, expected)
+    tvm.ir.assert_structural_equal(after_noapply, before)
+    tvm.ir.assert_structural_equal(after_apply_with_constr, before)
+
+    # Test roundtrip.
+    # Export as JSON.
+    json_obj = knob.as_json()
+    # Import JSON.
+    new_knob = Knob.from_json(json_obj)
+    assert new_knob.name == knob.name
+    # Test imported knob
+    assert new_knob.verify("apply")
+    assert new_knob.verify("noapply")
+    assert new_knob.verify("apply_with_impossible_constr")
+    assert not new_knob.verify("INVALID")
+
+    after_apply = new_knob.apply(before, "apply")
+    after_noapply = new_knob.apply(before, "noapply")
+    # Because the constraint was not satisfied, it will return the original IRModule
+    after_apply_with_constr = knob.apply(before, "apply_with_impossible_constr")
+    tvm.ir.assert_structural_equal(after_apply, expected)
+    tvm.ir.assert_structural_equal(after_noapply, before)
+    tvm.ir.assert_structural_equal(after_apply_with_constr, before)
+
+
+def test_trace():
+    before, expected = setup_test_const_folding()
+
+    # Define choices and their knob.
+    choices = {
+        "apply": Choice(
+            transform_func_key="testing.apply_fold_constant",
+            transform_func_args=[],
+            constr_func_key="testing.check_num_functions",
+            constr_func_args=[2],
+        ),
+        "noapply": Choice(),
+    }
+    knob = Knob("TestKnob", choices)
+
+    # Define a Trace with an empty decision (transformation) history.
+    trace = Trace(before)
+    assert trace.size == 0
+
+    # Define a Trace with a single decision (transformation) history.
+    trace = Trace(before, [knob], ["noapply"])
+    assert trace.size == 1
+    tvm.ir.assert_structural_equal(trace.in_mod, before)
+    tvm.ir.assert_structural_equal(trace.out_mod, before)
+
+    # Add a new knob and its decision to the trace.
+    # It updates the current trace and returns its new output IRModule.
+    out: IRModule = trace.add(knob, "noapply")
+    assert trace.size == 2
+    tvm.ir.assert_structural_equal(trace.in_mod, before)
+    tvm.ir.assert_structural_equal(trace.out_mod, before)
+    tvm.ir.assert_structural_equal(out, before)
+    # Assume we assign an arbitrary performance number.
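+    # (A sketch of the perf lifecycle, per the assertions below: set_perf
+    # records a score for the current trace, and adding a new knob resets
+    # perf to -1 until the new output module is evaluated again.)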
+ trace.set_perf(100)
+ assert trace.perf == 100
+
+ # Add a new knob and its decision to the trace.
+ out: IRModule = trace.add(knob, "apply")
+ tvm.ir.assert_structural_equal(trace.in_mod, before)
+ tvm.ir.assert_structural_equal(trace.out_mod, expected)
+ tvm.ir.assert_structural_equal(out, expected)
+
+ assert trace.size == 3
+ # perf should be re-initialized when a new knob is applied.
+ assert trace.perf == -1
+
+ # Test roundtrip.
+ # Export as JSON.
+ json_obj = trace.as_json()
+ # Import JSON.
+ new_trace = Trace.from_json(json_obj)
+ tvm.ir.assert_structural_equal(trace.in_mod, new_trace.in_mod)
+ assert str(trace) == str(new_trace)
+ assert new_trace.size == 3
+ tvm.ir.assert_structural_equal(trace.out_mod, new_trace.out_mod)
+
+
+def test_trace_wrapper():
+ mod = setup_test()
+ assert isinstance(mod, tvm.IRModule)
+ assert isinstance(Trace(mod), Trace)
+ assert isinstance(get_trace(mod), Trace)
+ assert isinstance(get_trace(mod["main"]), Trace)
+ assert isinstance(get_trace(mod["addone"]), Trace)
+
+
+def create_tmp_database(tmpdir: str) -> JSONDatabase:
+ path_workload = osp.join(tmpdir, "workloads.json")
+ path_tuning_record = osp.join(tmpdir, "tuning_records.json")
+ path_measurement_record = osp.join(tmpdir, "measurement_records.json")
+ return JSONDatabase(path_workload, path_tuning_record, path_measurement_record)
+
+
+def test_database():
+ def equal_measurement_record(a: List[float], b: List[float]):
+ assert len(a) == len(b)
+ for i in range(len(a)):
+ assert isclose(a[i], b[i], rel_tol=1e-5)
+
+ def equal_tuning_record(a: TuningRecord, b: TuningRecord):
+ assert str(a.trace) == str(b.trace)
+ equal_measurement_record(a.run_secs, b.run_secs)
+
+ # Test setup.
+ (
+ mod1,
+ mod2,
+ ) = setup_test_const_folding()
+ knob = Knob("test", {"noapply": Choice()})
+ trace = Trace(mod1, [knob, knob], ["noapply", "noapply"])
+ target = tvm.target.Target("llvm")
+
+ # Test roundtrip.
+ run_secs = [1.0, 0.9, 0.4]
+ tuning_record = TuningRecord(
+ trace,
+ run_secs,
+ )
+ new_tuning_record = TuningRecord.from_json(json_obj=tuning_record.as_json())
+ equal_tuning_record(tuning_record, new_tuning_record)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ database = create_tmp_database(tmpdir)
+ workload1 = database.commit_workload(mod1)
+
+ database.commit_measurement_record(workload1, target, run_secs)
+ new_run_secs1 = database.get_measurement_record(workload1, target)
+ equal_measurement_record(run_secs, new_run_secs1)
+ workload2 = database.commit_workload(mod2)
+ new_run_secs2 = database.get_measurement_record(workload2, target)
+ assert len(new_run_secs2) == 0
+
+ database.commit_tuning_record(workload1, target, tuning_record)
+ new_tuning_records = database.get_top_k(workload1, target, top_k=1)
+ assert len(new_tuning_records) == 1
+ equal_tuning_record(tuning_record, new_tuning_records[0])
+ new_tuning_records = database.get_top_k(workload1, target, top_k=0)
+ assert len(new_tuning_records) == 0
+
+
+def test_default_functions():
+ mod = setup_test()
+ assert isinstance(mod, tvm.IRModule)
+
+ # Define choice, knob, trace.
+ choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()}
+ knob = Knob("TestKnob", choices)
+ trace = Trace(mod)
+
+ # Launch a pass pipeline in trace mode.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ database = create_tmp_database(tmpdir)
+ with transform.PassContext(trace=trace, tuning_api_database=database):
+ # The default generation function expands every valid choice.
+ candidates = default_generate_candidate([knob], trace)
+ assert len(candidates) == 2
+
+ # The default evaluation function uses the MetaSchedule builder/runner.
+ # Since no builder/runner is provided, the local builder/runner will be used.
+ default_evaluate(candidates, "llvm --num-cores=16")
+ assert PassContext.current().num_evals == 2
+
+ # Because these candidates are already evaluated, num_evals stays the same.
+ default_evaluate(candidates, "llvm --num-cores=16")
+ assert PassContext.current().num_evals == 2
+
+ # Test with multiple knobs.
+ candidates = default_generate_candidate([knob, knob], trace)
+ assert len(candidates) == 4
+
+ # Launch a new pass pipeline in trace mode.
+ with transform.PassContext(trace=trace, tuning_api_database=database):
+ candidates = default_generate_candidate([knob], trace)
+ assert len(candidates) == 2
+ # Provide a tuning pass as an eval pass.
+ # Note that MockConstFoldingTuningPass() has its own generation and evaluation functions.
+ # Evaluation is done in a tournament fashion.
+ # `default_consider_eval_passes` converts each candidate into its best version under the given eval_passes.
+ # For example, if we say candidates = [C1, C2],
+ # `default_consider_eval_passes` returns the best variant of C1 (C11 vs. C12) and of C2 (C21 vs. C22)
+ # that can be generated by eval_passes.
+ # Assuming C11 > C12 and C21 < C22,
+ # new_candidates = [C11, C22]
+ new_candidates = default_consider_eval_passes(
+ candidates, [MockConstFoldingTuningPass(eval_passes=[])]
+ )
+
+ # len(candidates) == len(new_candidates).
+ assert len(new_candidates) == 2
+ # To find the best version of each candidate, it takes 4 evals (C11, C12, C21, C22).
+ assert PassContext.current().num_evals == 4
+
+ HeuristicPass = relax.transform.FoldConstant
+ with transform.PassContext(trace=trace, tuning_api_database=database):
+ candidates = default_generate_candidate([knob], trace)
+ assert len(candidates) == 2
+ # Provide a heuristic pass as an eval pass.
+ new_candidates = default_consider_eval_passes(candidates, [HeuristicPass()])
+ # Since a heuristic pass has a single decision, no tournament is needed.
+ # new_candidates = [C11, C21]
+ assert len(new_candidates) == 2
+ # Evaluation is only conducted when necessary (e.g., to choose the better candidate in a tuning pass).
+ # A heuristic pass won't conduct any evaluation.
+ assert PassContext.current().num_evals == 0
+
+
+# TODO(sunggg): Do we need to serialize the pass context as well?
+def test_pass_context():
+ before, expected = setup_test_const_folding()
+ HeuristicPass = relax.transform.FoldConstant
+ # FoldConstant implicitly performs TIR passes (probably for constant evaluation).
+ # If make_traceable is not provided, the pass infra makes every non-traceable pass traceable by default.
+ seq = transform.Sequential([HeuristicPass()])
+ with transform.PassContext(
+ trace=Trace(before),
+ ):
+ after = seq(before)
+ tvm.ir.assert_structural_equal(after, expected)
+ assert PassContext.current().get_trace_stack_size() == 1
+ # The exact number of implicit passes might change as TVM develops more passes.
+ # As of today, this size is 57.
+ assert PassContext.current().get_current_trace().size > 1
+
+ # We can explicitly specify which passes we want to keep track of.
+ with transform.PassContext(trace=Trace(before), make_traceable=["FoldConstant"]):
+ after = seq(before)
+ tvm.ir.assert_structural_equal(after, expected)
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 1
+
+ # Check the functionality of the trace stack.
+ with transform.PassContext(trace=Trace(before)):
+ assert PassContext.current().get_trace_stack_size() == 1
+ PassContext.current().push_trace(Trace(before))
+ assert PassContext.current().get_trace_stack_size() == 2
+ PassContext.current().pop_trace()
+ assert PassContext.current().get_trace_stack_size() == 1
+ PassContext.current().pop_trace()
+ assert PassContext.current().get_trace_stack_size() == 0
+
+
+# Mock evaluation pass for testing.
+# Assigns an arbitrary performance number to each candidate.
+def mock_evaluate(candidates: List[Trace], target_str: str, ctx: PassContext):
+ num_evals = 0
+ # Evaluation
+ for candidate in candidates:
+ # If this candidate is already evaluated, skip the measurement.
+ if candidate.perf != -1:
+ continue
+
+ num_evals += 1
+ # Assign an arbitrary performance number.
+ mock_perf = 100 - (ctx.num_evals + num_evals)
+ candidate.set_perf(mock_perf)
+ # Update the number of evals for testing.
+ ctx.inc_num_evals(num_evals)
+
+
+# Mock tuning pass that determines whether to apply relax.transform.FoldConstant().
+# Each pass invocation will generate two candidates for the incoming IRModule.
+# In the Relax pass infra, each pass defines its own way of generating and evaluating candidates,
+# without needing to know how other passes generate and evaluate theirs.
+# This significantly simplifies development, since reasoning about the interactions with
+# (potentially hundreds of) other passes is known to be a hard problem.
+@ir.transform.module_pass(opt_level=0, traceable=True)
+class MockConstFoldingTuningPass(transform.Pass):
+ def __init__(
+ self,
+ f_generate_candidate=None,
+ f_evaluate=mock_evaluate,
+ eval_passes: List[transform.Pass] = None,
+ required: List[transform.Pass] = [],
+ ):
+ self.f_generate_candidate = (
+ f_generate_candidate if f_generate_candidate else default_generate_candidate
+ )
+ self.f_evaluate = f_evaluate if f_evaluate else default_evaluate
+ self.eval_passes = eval_passes
+ self.required = required
+
+ def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
+ trace = ctx.pop_trace()
+
+ # Create mock choices for testing.
+ choices = {"apply": Choice("testing.apply_fold_constant"), "noapply": Choice()}
+ # A tuning pass manages a set of transformation functions registered via a knob.
+ knob = Knob("MockTuningKnob", choices)
+
+ candidates = self.f_generate_candidate([knob], trace, self.eval_passes)
+ self.f_evaluate(candidates, "llvm", ctx)
+ best_trace = select_best_candidate(candidates)
+
+ ctx.push_trace(best_trace)
+ return best_trace.out_mod
+
+
+def test_module_pass():
+ mod = setup_test()
+ assert isinstance(mod, tvm.IRModule)
+ # Test setup.
+ c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+ mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod)
+ HeuristicPass = relax.transform.FoldConstant
+
+ # A tuning pass without any eval_pass.
+ mock_pass = MockConstFoldingTuningPass(eval_passes=[])
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = mock_pass(mod)
+ assert PassContext.current().num_evals == 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 1
+
+ # A heuristic pass should not affect the number of candidates.
+ mock_pass = MockConstFoldingTuningPass(eval_passes=[HeuristicPass()])
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = mock_pass(mod)
+ assert PassContext.current().num_evals == 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 2
+
+ # Joint optimization increases the search space in a combinatorial way.
+ mock_pass = MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])])
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = mock_pass(mod)
+ assert PassContext.current().num_evals == 2 * 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 2
+
+ # Joint optimization can be nested.
+ mock_pass = MockConstFoldingTuningPass(
+ eval_passes=[
+ MockConstFoldingTuningPass(eval_passes=[MockConstFoldingTuningPass(eval_passes=[])])
+ ]
+ )
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = mock_pass(mod)
+ assert PassContext.current().num_evals == 2 * 2 * 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 3
+
+ # Tuning passes and heuristic passes can be used together.
+ # Note that a heuristic pass won't increase the search space (num_evals).
+ # It only increases the length of the trace.
+ mock_pass = MockConstFoldingTuningPass(
+ eval_passes=[
+ HeuristicPass(),
+ MockConstFoldingTuningPass(
+ eval_passes=[
+ MockConstFoldingTuningPass(eval_passes=[HeuristicPass(), HeuristicPass()])
+ ]
+ ),
+ ]
+ )
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = mock_pass(mod)
+ assert PassContext.current().num_evals == 2 * 2 * 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 6
+
+ # Users can mix sequential application and joint application.
+ mock_pass = MockConstFoldingTuningPass(
+ eval_passes=[
+ MockConstFoldingTuningPass(eval_passes=[]),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ ]
+ )
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = mock_pass(mod)
+ assert PassContext.current().num_evals == 2 * (2 + 2 + 2)
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 4
+
+
+def test_sequential():
+ mod = setup_test()
+ assert isinstance(mod, tvm.IRModule)
+ # Test setup.
+ c0 = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+ mod = relax.transform.BindParams("main", {"c0": tvm.nd.array(c0)})(mod)
+ HeuristicPass = relax.transform.FoldConstant
+
+ # Sequential with a single tuning pass should behave the same as the single pass alone.
+ seq = transform.Sequential([MockConstFoldingTuningPass(eval_passes=[])])
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = seq(mod)
+ assert PassContext.current().num_evals == 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 1
+
+ # A Sequential pass should increase the search space (num_evals) in an additive manner.
+ seq = transform.Sequential(
+ [
+ MockConstFoldingTuningPass(eval_passes=[]),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ ]
+ )
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = seq(mod)
+ assert PassContext.current().num_evals == 2 + 2 + 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 3
+
+ # A heuristic pass does not increase the search space; it only increases the trace length.
+ seq = transform.Sequential(
+ [
+ MockConstFoldingTuningPass(eval_passes=[]),
+ HeuristicPass(),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ HeuristicPass(),
+ ]
+ )
+
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = seq(mod)
+ assert PassContext.current().num_evals == 2 + 2 + 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 5
+
+ # Users can mix sequential application and joint application.
+ seq = transform.Sequential(
+ [
+ HeuristicPass(),
+ MockConstFoldingTuningPass(
+ eval_passes=[
+ MockConstFoldingTuningPass(
+ eval_passes=[
+ MockConstFoldingTuningPass(
+ eval_passes=[
+ HeuristicPass(),
+ ]
+ )
+ ]
+ ),
+ ]
+ ),
+ MockConstFoldingTuningPass(eval_passes=[]),
+ HeuristicPass(),
+ ]
+ )
+
+ with transform.PassContext(trace=Trace(mod), make_traceable=["FoldConstant"]):
+ _ = seq(mod)
+ assert PassContext.current().num_evals == (2 * 2 * 2) + 2
+ assert PassContext.current().get_trace_stack_size() == 1
+ assert PassContext.current().get_current_trace().size == 7
+
+
+def test_passes_with_mixed_granularities():
+ @tvm.script.ir_module
+ class MockModule:
+ @R.function
+ def f1(x: Tensor((m, n), "float32")):
+ with relax.dataflow():
+ lv0 = relax.multiply(x, x)
+ gv0 = relax.add(x, x)
+ relax.output(gv0)
+ return gv0
+
+ @R.function
+ def main(x: Tensor((m, n), "float32"), y: Tensor((m, n), "float32")):
+ with relax.dataflow():
+ lv0 = relax.multiply(x, y)
+ gv0 = relax.add(lv0, y)
+ relax.output(gv0)
+ gv1 = relax.multiply(x, y)
+ gv2 = relax.add(gv1, y)
+ return (gv0, gv1, gv2)
+
+ mod = MockModule
+ assert isinstance(mod, tvm.IRModule)
+
+ # Helper function for tuning.
+ def pass_func(
+ mod: IRModule, ctx: PassContext, eval_passes: List[transform.Pass] = None
+ ) -> IRModule:
+ trace = ctx.pop_trace()
+
+ # Create mock choices for testing.
+ choices = [Choice(), Choice(), Choice()]
+ # A tuning pass manages a set of transformation functions registered via a knob.
+ knob = Knob("MockTuningKnob", choices)
+
+ candidates = default_generate_candidate([knob], trace, eval_passes)
+ mock_evaluate(candidates, "llvm", ctx)
+ best_trace = select_best_candidate(candidates)
+
+ ctx.push_trace(best_trace)
+ return best_trace.out_mod
+
+ @ir.transform.module_pass(opt_level=0, traceable=True)
+ def MockModulePass(mod: IRModule, ctx: PassContext) -> IRModule:
+ # Input granularity == Candidate granularity.
+ return pass_func(mod, ctx)
+
+ @relax.transform.function_pass(opt_level=0, traceable=True)
+ def MockFunctionPass(func: Expr, mod: IRModule, ctx: PassContext) -> Function:
+ # Input granularity > Candidate granularity.
+ # Start a trace with a smaller granularity: IRModule -> Function.
+ ctx.push_trace(Trace(IRModule.from_expr(func)))
+ # Do something.
+ pass_func(mod, ctx)
+ # Pop the tuned trace and recover the previous trace.
+ ctx.pop_trace()
+ return func
+
+ @relax.transform.dataflowblock_pass(opt_level=0, traceable=True)
+ def MockDataflowBlockPass(
+ block: DataflowBlock, mod: IRModule, ctx: PassContext
+ ) -> DataflowBlock:
+ # TODO(sunggg): figure out how to create IRModule from DataflowBlock
+ # Provide a random binding for now.
+ x = relax.Var("x", [tvm.tir.Var("n", "int64")], relax.DynTensorType(1, "float32"))
+ seq_expr = relax.SeqExpr([block], x)
+ ret_type = relax.DynTensorType(-1, "float32")
+ ret_shape = relax.RuntimeDepShape()
+ func = relax.Function([x], seq_expr, ret_type, ret_shape)
+ ctx.push_trace(Trace(IRModule.from_expr(func)))
+ # Do something.
+ pass_func(mod, ctx)
+ ctx.pop_trace()
+ return block
+
+ seq = transform.Sequential(
+ [
+ MockModulePass,
+ MockFunctionPass,
+ MockDataflowBlockPass,
+ ]
+ )
+
+ with transform.PassContext(trace=Trace(mod), make_traceable=[]):
+ _ = seq(mod)
+ # The trace length and number of evals can differ depending on how each function/dataflow block is treated.
+ assert PassContext.current().get_trace_stack_size() == 1
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py
new file mode 100644
index 0000000000..a80f71d62a
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_ir_builder.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
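+# A quick orientation for the tests below: each one constructs the same Relax
+# function twice -- imperatively with the Script IRBuilder, and
+# programmatically with relax.BlockBuilder -- and then asserts the two results
+# are structurally equal. A minimal sketch of that shared pattern, using only
+# APIs exercised in this file (the function name "f" is illustrative only):
+#
+#   with IRBuilder() as ib:                      # Script IRBuilder path
+#       with R.function():
+#           R.func_name("f")
+#           x = R.arg("x", R.tensor((128, 128), "float32"))
+#           R.func_ret_value(x)
+#   func = ib.get()
+#
+#   x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32"))
+#   bb = relax.BlockBuilder()                    # BlockBuilder path
+#   with bb.function("f", (x,)):
+#       bb.emit_func_output(x)
+#
+#   tvm.ir.assert_structural_equal(func, bb.get()["f"])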
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm.script.ir_builder import relax as R
+from tvm.script.ir_builder.base import IRBuilder
+
+
+def test_function_simple():
+ """
+ @R.function
+ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
+ out = R.call_tir("extern_func", x, (128, 128), dtype="float32")
+ return out
+ """
+ # Create with the Script IRBuilder.
+ with IRBuilder() as ir_builder:
+ with R.function():
+ R.func_name("foo")
+ R.func_attr({"Primitive": 1})
+ x = R.arg("x", R.tensor((128, 128), "float32"))
+ R.func_ret_type(R.tensor(dtype="float32", ndim=2))
+ out = R.emit(
+ R.call_tir("extern_func", x, (128, 128), dtype="float32"), is_dataflow_var=False
+ )
+ IRBuilder.name("out", out)
+ R.func_ret_value(out)
+ func = ir_builder.get()
+ # Create with the BlockBuilder.
+ x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,), attrs={"Primitive": 1}):
+ out = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32"))
+ bb.emit_func_output(out)
+ mod = bb.get()
+
+ tvm.ir.assert_structural_equal(func, mod["foo"])
+ # Check names.
+ assert func.attrs["global_symbol"] == "foo"
+ assert func.params[0].name_hint == "x"
+ assert func.body.body.name_hint == "out"
+
+
+def test_match_shape():
+ """
+ @R.function
+ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")):
+ m = T.var("int64")
+ n = T.var("int64")
+ R.match_shape(x, (m,))
+ y1 = R.match_shape(y, (n,))
+ return (m, n * 2)
+ """
+ # Create with the Script IRBuilder.
+ with IRBuilder() as ir_builder:
+ with R.function():
+ R.func_name("foo")
+ x = R.arg("x", R.tensor(ndim=-1, dtype="float32"))
+ y = R.arg("y", R.tensor(ndim=-1, dtype="float32"))
+ m = tir.Var("m", dtype="int64")
+ n = tir.Var("n", dtype="int64")
+ R.emit_match_shape(x, (m,), emit_var=False, is_dataflow_var=False)
+ y1 = R.emit_match_shape(y, (n,), emit_var=True, is_dataflow_var=False)
+ IRBuilder.name("y1", y1)
+ R.func_ret_value(relax.ShapeExpr([m, n * 2]))
+ func = ir_builder.get()
+
+ # Create with the BlockBuilder.
+ x = relax.Var("x", type_annotation=relax.DynTensorType(-1, "float32"))
+ y = relax.Var("y", type_annotation=relax.DynTensorType(-1, "float32"))
+ m = tir.Var("m", dtype="int64")
+ n = tir.Var("n", dtype="int64")
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x, y)):
+ bb.match_shape_binding(relax.MatchShape(x, (m,), var=None))
+ y1 = bb.match_shape(y, (n,))
+ bb.emit_func_output(relax.ShapeExpr([m, n * 2]))
+ mod = bb.get()
+
+ tvm.ir.assert_structural_equal(func, mod["foo"])
+
+
+def test_dataflow_block():
+ """
+ @R.function
+ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
+ # block 0
+ with R.dataflow():
+ lv0 = R.call_tir("extern_func", (x,), (128, 128), dtype="float32")
+ gv: Tensor((128, 128), "float32") = lv0
+ R.output(gv)
+ return gv
+ """
+ # Create with the Script IRBuilder.
+ with IRBuilder() as ir_builder:
+ with R.function():
+ R.func_name("foo")
+ x = R.arg("x", R.tensor((128, 128), "float32"))
+ with R.dataflow():
+ lv0 = R.emit(
+ R.call_tir("extern_func", x, (128, 128), dtype="float32"), is_dataflow_var=True
+ )
+ IRBuilder.name("lv0", lv0)
+ gv = R.emit(lv0, is_dataflow_var=False)
+ IRBuilder.name("gv", gv)
+ R.output(gv)
+ R.func_ret_value(gv)
+ func = ir_builder.get()
+
+ # Create with the BlockBuilder.
+ x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,)):
+ with bb.dataflow():
+ lv0 =
bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + tvm.ir.assert_structural_equal(func, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py new file mode 100644 index 0000000000..bdd7c8c08b --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser.py @@ -0,0 +1,638 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Union + +import pytest +import tvm +import tvm.testing +from tvm import IRModule, relax, tir +from tvm.script._parser import ir as I +from tvm.script._parser import relax as R +from tvm.script._parser import tir as T + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Union[relax.Function, IRModule], +): + # TODO(siyuan): add round-trip tests + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_simple_func(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + R.func_attr({"Primitive": 1}) + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + return gv0 + + x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +def test_error_report(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = gv1 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + return gv0 + + +def test_simple_module(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func(x: T.Buffer((128, 128), "float32"), y: T.Buffer((128, 128), "float32")): + T.func_attr({"global_symbol": "tir_func", "tir.noalias": True}) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + # TODO(Siyuan): Need to change to `TestModule.tir_func` + gv0 = R.call_tir(tir_func, x, (128, 128), dtype="float32") + return gv0 + + x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + bb.emit_func_output(out) + + _check(TestModule, bb.get()) + + +def test_relax_tensor_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor(None, "float32", ndim=2): + y = R.add(x, x) + z = R.multiply(x, y) + return z + + x = relax.Var("x", [4, 4], 
relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_relax_base_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + alloc = R.builtin.alloc_tensor((4, 4), runtime_device_index=0, dtype="float32") + shape = R.shape_of(alloc) + return shape + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) + shape = bb.emit(relax.op.shape_of(alloc)) + bb.emit_func_output(shape) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64", "m") + n = T.var("int64", "n") + gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + return gv0 + + @R.function + def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64") + n = T.var("int64") + gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + return gv0 + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64") + n = T.var("int32") # The shape dtype should be int64 + gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + return gv0 + + def _expected(name: str): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", [m, n], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function(name, (x,)): + out = bb.emit(relax.call_tir("extern_func", x, (m, n), dtype="float32")) + bb.emit_func_output(out) + return bb.get()[name] + + _check(foo, _expected("foo")) + _check(bar, _expected("bar")) + + +def test_shadowing(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + y = R.add(x, x) + z = R.multiply(x, y) + y = R.add(x, y) + y = z + y = R.multiply(y, x) + z = y + return z + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + y = bb.emit(relax.op.add(x, y)) + y = bb.emit(z) + y = bb.emit(relax.op.multiply(y, x)) + z = bb.emit(y) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_match_shape(): + @R.function + def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + m = T.var("int64") + n = T.var("int64") + R.match_shape(x, (m,)) + y1 = R.match_shape(y, (n,)) + return (m, n * 2) + + x = relax.Var("x", type_annotation=relax.DynTensorType(-1, "float32")) + y = relax.Var("y", type_annotation=relax.DynTensorType(-1, "float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + bb.match_shape_binding(relax.MatchShape(x, (m,), var=None)) + y1 = bb.match_shape(y, (n,)) + bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + _check(foo, bb.get()["foo"]) + + +def test_tuple_return(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + gv0 = R.call_tir("extern_func_0", x, (4, 4), dtype="float32") + gv1 = R.call_tir("extern_func_1", x, (4, 4), dtype="float32") + return (gv0, gv1) + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + 
gv0 = bb.emit(relax.call_tir("extern_func_0", x, (4, 4), dtype="float32")) + gv1 = bb.emit(relax.call_tir("extern_func_1", x, (4, 4), dtype="float32")) + bb.emit_func_output(relax.Tuple((gv0, gv1))) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + lv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + lv1 = R.call_tir("extern_func", lv0, (128, 128), dtype="float32") + gv = lv1 + R.output(gv) + return gv + + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + lv1 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block_advanced(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv1 = R.call_tir("extern_func", gv0, (128, 128), dtype="float32") + with R.dataflow(): + m = T.var("int64") + n = T.var("int64") + lv0 = R.call_tir("extern_func", gv1, (128, 128), dtype="float32") + lv1 = R.match_shape(lv0, (m, n)) + gv2 = R.call_tir("extern_func", lv0, (128, 128), dtype="float32") + gv2 = R.call_tir("extern_func", gv2, (128, 128), dtype="float32") + gv3 = R.match_shape(gv2, (m, n)) + gv3 = R.match_shape(lv0, (m, n)) + gv4 = gv3 + gv5 = gv2 + R.output(gv5, gv4) + gv6 = R.call_tir("extern_func", gv5, (128, 128), dtype="float32") + gv7 = R.call_tir("extern_func", gv6, (128, 128), dtype="float32") + return gv7 + + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + gv1 = bb.emit(relax.call_tir("extern_func", gv0, (128, 128), dtype="float32")) + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", gv1, (128, 128), dtype="float32")) + lv1 = bb.match_shape(lv0, (m, n)) + gv2 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) + gv21 = bb.emit(relax.call_tir("extern_func", gv2, (128, 128), dtype="float32")) + gv3 = bb.match_shape(gv21, (m, n)) + gv31 = bb.match_shape(lv0, (m, n)) + gv32 = bb.emit_output(gv31) + gv22 = bb.emit_output(gv21) + gv4 = bb.emit(relax.call_tir("extern_func", gv22, (128, 128), dtype="float32")) + gv5 = bb.emit(relax.call_tir("extern_func", gv4, (128, 128), dtype="float32")) + bb.emit_func_output(gv5) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_binding_after_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + R.output(gv) + lv = R.call_tir("extern_func", gv, (128, 128), dtype="float32") + return gv + + +def test_dataflow_output_global_var(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + with R.dataflow(): + gv1 = R.call_tir("extern_func", gv0, (128, 128), dtype="float32") + R.output(gv0, gv1) 
+ return gv1 + + +def test_dataflow_multiple_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + R.output(gv) + R.output(gv) + return gv + + +def test_dataflow_output_outside_dataflow_block(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + R.output(gv) + return gv + + +def test_return_without_binding(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_multiple_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + return x + + +def test_function_without_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + + +def test_tensor_type_without_args(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + v = R.call_tir("tir_relu", x, (32, 32), dtype="float32") + return v + + x = relax.Var("x", (32, 32), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + v = bb.emit(relax.call_tir("tir_relu", x, (32, 32), dtype="float32")) + bb.emit_func_output(v) + + _check(foo, bb.get()["foo"]) + + +def test_direct_return(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + x = relax.Var("x", (32, 32), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_call_packed(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + z = R.call_packed("vm.builtin.copy", x, type_args=R.Tensor((32, 32), "float32")) + return z + + x = relax.Var("x", (32, 32), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + z = bb.emit( + relax.Call( + relax.ExternFunc("vm.builtin.copy"), + (x,), + None, + type_args=[relax.DynTensorType(2, "float32")], + ) + ) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_annotation(): + @R.function + def foo( + x: R.Tensor((32, "m"), "float32"), + y: R.Tensor(("m"), "float32"), + r: R.Tensor(dtype="int64"), + ) -> R.Object: + m = T.var("int64") + z: R.Tensor((32, m), "float32") = R.multiply(x, y) + w: R.Tensor = R.multiply(z, z) + q: R.Tensor(ndim=2) = R.add(w, w) + t = R.add(w, z) + sh: R.Shape = R.shape_of(t) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object) + return o + + def _check_type_shape(binding, expected_type, expected_shape): + tvm.ir.assert_structural_equal(binding.var.checked_type, expected_type) + tvm.ir.assert_structural_equal(binding.var.shape_, expected_shape) + + # Cannot use block builder here because we need to check the annotated type, + # which may be inconsistent with deduced type. 
+ assert isinstance(foo.ret_type, relax.ObjectType) + m = foo.params[0].shape[1] + bindings = foo.body.blocks[0].bindings + _check_type_shape( + bindings[0], relax.DynTensorType(ndim=2, dtype="float32"), relax.ShapeExpr([32, m]) + ) + _check_type_shape(bindings[1], relax.DynTensorType(dtype=""), None) + _check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), None) + _check_type_shape(bindings[3], relax.DynTensorType(dtype=""), None) + _check_type_shape(bindings[4], relax.ShapeType(), None) + _check_type_shape(bindings[5], relax.ObjectType(), None) + + +def test_annotate_override(): + @R.function + def foo(x: R.Tensor): + y = x + # z will be treated as object type even though it's a tensor + z: R.Object = y + return z + + assert isinstance(foo.ret_type, relax.ObjectType) + y_bind, z_bind = foo.body.blocks[0].bindings + assert isinstance(y_bind.var.checked_type, relax.DynTensorType) + assert isinstance(z_bind.var.checked_type, relax.ObjectType) + + +def test_empty_shape(): + @R.function + def foo(x: R.Tensor((), "float32")): + z = R.call_tir("scalar_add", x, (), dtype="float32") + return z + + (z_bind,) = foo.body.blocks[0].bindings + shape_expr = z_bind.value.args[2] + + assert isinstance(shape_expr, relax.ShapeExpr) + assert len(shape_expr.values) == 0 + + +def test_local_function(): + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + main_bindings = main.body.blocks[0].bindings + assert len(main_bindings) == 3 + outer_func = main_bindings[0].value + assert isinstance(outer_func, relax.Function) + + outer_func_bindings = outer_func.body.blocks[0].bindings + assert len(outer_func_bindings) == 1 + inner_func = outer_func_bindings[0].value + assert isinstance(inner_func, relax.Function) + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), (128, 128), dtype="float32") + return z + + bindings = TestModule["f"].body.blocks[0].bindings + assert len(bindings) == 2 + tir_func = bindings[0].value + assert isinstance(tir_func, tir.PrimFunc) + + +def test_if_branch(): + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return y + + cond, x = foo.params + y_bind = foo.body.blocks[0].bindings[0] + y, ite = y_bind.var, y_bind.value + + assert isinstance(y, relax.Var) + assert y.name_hint == "y" + + assert isinstance(ite, relax.If) + assert isinstance(ite.true_branch, relax.SeqExpr) + assert isinstance(ite.false_branch, relax.SeqExpr) + + def check_call(call, op, args): + assert isinstance(call, relax.Call) + if 
isinstance(op, str):
+ assert str(call.op) == op
+ else:
+ assert call.op == op
+ tvm.ir.assert_structural_equal(call.args, args)
+
+ w_bind = ite.true_branch.blocks[0].bindings[0]
+ body = ite.true_branch.body
+ assert w_bind.var.name_hint == "w"
+ check_call(w_bind.value, "relax.add", [x, x])
+ check_call(body, "relax.multiply", [w_bind.var, w_bind.var])
+
+ w_bind = ite.false_branch.blocks[0].bindings[0]
+ body = ite.false_branch.body
+ assert w_bind.var.name_hint == "w"
+ check_call(w_bind.value, "relax.multiply", [x, x])
+ check_call(body, "relax.add", [w_bind.var, w_bind.var])
+
+
+def test_if_inside_dataflow():
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @R.function
+ def foo(
+ cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
+ ) -> R.Tensor((1,), "float32"):
+ with R.dataflow():
+ if cond:
+ w = R.add(x, x)
+ y = R.multiply(w, w)
+ else:
+ w = R.multiply(x, x)
+ y = R.add(w, w)
+ R.output(y)
+ return y
+
+
+def test_if_branch_output_name():
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @R.function
+ def foo(
+ cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
+ ) -> R.Tensor((1,), "float32"):
+ if cond:
+ w = R.add(x, x)
+ y = R.multiply(w, w)
+ else:
+ w = R.multiply(x, x)
+ z = R.add(w, w)
+ return y
+
+
+def test_if_branch_var_scope():
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @R.function
+ def foo(
+ cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")
+ ) -> R.Tensor((1,), "float32"):
+ if cond:
+ w = R.add(x, x)
+ y = R.multiply(w, w)
+ else:
+ w = R.multiply(x, x)
+ y = R.add(w, w)
+ return w
+
+
+def test_other_cases():
+ # These are corner-case tests that only check whether the code can be parsed.
+ # There is no need to add structural equality checks here.
+ @R.function
+ def foo(x: R.Tensor):
+ return R.unique(x, sorted=True)
+
+ @R.function
+ def bar(x: R.Tensor):
+ return R.print(x, format="{}")
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_type.py b/tests/python/relax/test_type.py
new file mode 100644
index 0000000000..9c6647f136
--- /dev/null
+++ b/tests/python/relax/test_type.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
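+# The subtype checks below (via is_base_of) follow a few rules worth keeping
+# in mind while reading: ndim=-1 and dtype=None act as wildcards for
+# DynTensorType, TupleType is covariant in its fields, and FuncType is (as
+# these tests exercise it) covariant in both argument and return types. A
+# rough sketch of the tensor rule, for intuition only -- the helper below is
+# hypothetical, not the real implementation:
+#
+#   def _tensor_is_base_of(base, derived):
+#       ndim_ok = base.ndim == -1 or base.ndim == derived.ndim
+#       dtype_ok = base.dtype in (None, "") or base.dtype == derived.dtype
+#       return ndim_ok and dtype_ok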
+import pytest
+import numpy as np
+import tvm
+from tvm import relax as rx
+from tvm.relax.ty import is_base_of
+
+
+def test_shape_type():
+ t0 = rx.ShapeType()
+ t1 = rx.ShapeType()
+ assert t0 == t1
+
+
+def test_dyn_tensor_type():
+ t0 = rx.DynTensorType()
+ assert t0.ndim == -1
+ t1 = rx.DynTensorType(3, "int32")
+ assert t1.ndim == 3
+ assert t1.dtype == "int32"
+
+
+def test_subtype():
+ # Check the subtype relation for DynTensorType.
+ # e.g., DynTensorType(ndim=3, dtype="float32") is a subtype of DynTensorType(ndim=-1, dtype="float32"),
+ # and DynTensorType(ndim=-1, dtype="float32") is a subtype of DynTensorType(ndim=-1, dtype=None).
+ t0 = rx.DynTensorType(-1, None)
+ t1 = rx.DynTensorType(3, None)
+ t2 = rx.DynTensorType(3, "int32")
+ t3 = rx.DynTensorType(3, "float32")
+ t4 = rx.DynTensorType(3, "float32")
+ assert is_base_of(t0, t1)
+ assert is_base_of(t0, t2)
+ assert is_base_of(t0, t3)
+ assert is_base_of(t0, t4)
+ assert is_base_of(t1, t2)
+ assert is_base_of(t1, t3)
+ assert is_base_of(t4, t3)
+ # A type is a subtype of itself.
+ assert is_base_of(t2, t2)
+ assert is_base_of(t3, t3)
+ assert is_base_of(t2, t3) == False
+ assert is_base_of(t3, t2) == False
+
+ # Check the subtype relation for ShapeType.
+ t5 = rx.ShapeType()
+ t6 = rx.ShapeType()
+ assert is_base_of(t5, t5)
+ assert is_base_of(t5, t6)
+ assert is_base_of(t5, t0) == False
+
+ # Check the subtype relation for TupleType by checking whether each field
+ # of the base TupleType is a subtype of the corresponding field of the derived TupleType.
+ # e.g., TupleType([DynTensorType(ndim=3, dtype="float32"), ShapeType()])
+ # is a subtype of TupleType([DynTensorType(ndim=-1, dtype="float32"), ShapeType()]).
+ t7 = rx.TupleType([t0, t1, t5])
+ t8 = rx.TupleType([t1, t1, t5])
+ t9 = rx.TupleType([t1, t3, t5])
+ t10 = rx.TupleType([t5, t3, t1])
+ t11 = rx.TupleType([t1, t3])
+ assert is_base_of(t7, t8)
+ assert is_base_of(t7, t9)
+ assert is_base_of(t8, t9)
+ assert is_base_of(t8, t8)
+ assert is_base_of(t9, t7) == False
+ assert is_base_of(t7, t10) == False
+ assert is_base_of(t11, t7) == False
+ assert is_base_of(t7, t11) == False
+
+ # Check the subtype relation for FuncType by checking the subtype relations of arg_types and ret_type.
+ # e.g., FuncType([DynTensorType(ndim=3, dtype="float32")], DynTensorType(ndim=2, dtype="float32"))
+ # is a subtype of FuncType([DynTensorType(ndim=-1, dtype=None)], DynTensorType(ndim=-1, dtype="float32")).
+ t12 = rx.FuncType([t7], t0)
+ t13 = rx.FuncType([t7], t1)
+ t14 = rx.FuncType([t8], t0)
+ t15 = rx.FuncType([t8], t1)
+ t16 = rx.FuncType([t7, t0], t1)
+ t17 = rx.FuncType([t7, t4], t1)
+ assert is_base_of(t12, t13)
+ assert is_base_of(t12, t14)
+ assert is_base_of(t12, t15)
+ assert is_base_of(t12, t12)
+ assert is_base_of(t13, t14) == False
+ assert is_base_of(t13, t15)
+ assert is_base_of(t14, t15)
+ assert is_base_of(t16, t17)
+ assert is_base_of(t16, t16)
+ assert is_base_of(t12, t16) == False
+ assert is_base_of(t13, t16) == False
+
+ # Check the subtype relation for ObjectType.
+ # ObjectType is the base type of every type in Relax.
+ t18 = rx.ObjectType()
+ assert is_base_of(t18, t0)
+ assert is_base_of(t18, t5)
+ assert is_base_of(t18, t7)
+ assert is_base_of(t18, t12)
+ assert is_base_of(t18, t18)
+ assert is_base_of(t0, t18) == False
+ assert is_base_of(t5, t18) == False
+ assert is_base_of(t7, t18) == False
+ assert is_base_of(t12, t18) == False
+
+ # More complicated cases:
+ # TupleType with all possible types as fields.
+ t19 = rx.TupleType([t7, t0, t5, t12, t18])
+ t20 = rx.TupleType([t8, t1, t5,
t15, t18]) + t21 = rx.TupleType([t18, t18, t18, t18, t18]) + assert is_base_of(t19, t20) + assert is_base_of(t21, t19) + assert is_base_of(t21, t20) + assert is_base_of(t20, t19) == False + assert is_base_of(t18, t20) + # FuncType with all possible types as arg_types and ret_type + t22 = rx.FuncType([t7, t0, t5, t12, t18], t0) + t23 = rx.FuncType([t8, t1, t5, t15, t18], t1) + t24 = rx.FuncType([t7], t0) + t25 = rx.FuncType([t18, t18, t18, t18, t18], t18) + t26 = rx.FuncType([t18], t18) + t27 = rx.FuncType([t7, t0, t5, t12, t18], t19) + t28 = rx.FuncType([t7, t0, t5, t12, t18], t20) + assert is_base_of(t22, t23) + assert is_base_of(t25, t23) + assert is_base_of(t18, t23) + assert is_base_of(t18, t22) + assert is_base_of(t27, t28) + assert is_base_of(t24, t22) == False + assert is_base_of(t24, t23) == False + assert is_base_of(t26, t23) == False + assert is_base_of(t28, t27) == False + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py new file mode 100644 index 0000000000..292d1ebdfb --- /dev/null +++ b/tests/python/relax/test_vm.py @@ -0,0 +1,1262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
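+# Most tests in this file follow one of two patterns: either they assemble VM
+# bytecode by hand with relax.ExecBuilder and run it on relax.VirtualMachine,
+# or they build an IRModule (via TVMScript or BlockBuilder), compile it with
+# relax.vm.build, and run the result. A minimal sketch of the compile-and-run
+# path, using only APIs exercised below ("foo" stands for whichever Relax
+# function the module defines):
+#
+#   ex = relax.vm.build(mod, tvm.target.Target("llvm", host="llvm"))
+#   vm = relax.VirtualMachine(ex, tvm.cpu())
+#   res = vm["foo"](tvm.nd.array(np.random.rand(3, 4).astype(np.float32)))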
+from __future__ import annotations # must import to defer parsing of annotations +import os +from typing import Any, Callable, List, Tuple + +import sys +import tempfile +import numpy as np +import pytest +import tvm +from tvm.runtime.object import Object +import tvm.script +import tvm.testing +from tvm import relax, rpc, te, tir, topi, TVMError +from tvm.contrib import utils +from tvm.relax.testing import nn +from tvm.script import relax as R, tir as T + + +@tvm.register_func("test.vm.move") +def move(src): + return src + + +@tvm.register_func("test.vm.add") +def add(a, b): + ret = a.numpy() + b.numpy() + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.mul") +def mul(a, b): + ret = a.numpy() * b.numpy() + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.equal_zero") +def equal_zero(a): + ret = np.all((a.numpy() == 0)) + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.subtract_one") +def subtract_one(a): + ret = np.subtract(a.numpy(), 1) + return tvm.nd.array(ret) + + +@tvm.register_func("test.vm.identity") +def identity_packed(a, b): + b[:] = tvm.nd.array(a.numpy()) + + +@tvm.register_func("test.vm.tile") +def tile_packed(a, b): + b[:] = tvm.nd.array(np.tile(a.numpy(), (1, 2))) + + +def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: List[Any]) -> Object: + # uses save_function to create a closure with the given inputs + # and ensure the result is the same + # (assumes the functions return tensors and that they're idempotent) + saved_name = f"{func_name}_saved" + vm.save_function(func_name, saved_name, *inputs) + res1 = vm[func_name](*inputs) + res2 = vm[saved_name]() + tvm.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-7, atol=1e-7) + return res1 + + +def test_vm_execute(): + ib = relax.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + add_res = check_saved_func(vm, "func0", a, b) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_multiple_func(): + ib = relax.ExecBuilder() + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + with ib.function("func1", num_inputs=2): + ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + mul_res = check_saved_func(vm, "func1", a, b) + add_res = check_saved_func(vm, "func0", a, b) + tvm.testing.assert_allclose(add_res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(mul_res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_exec_serialize_export_library(): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: Tensor((3, 4), "float32")): + z = R.call_packed("vm.builtin.copy", x, type_args=(Tensor(ndim=2, dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + from tvm.contrib import utils + + temp_dir = utils.tempdir() + path_exec = temp_dir.relpath("exec.so") + ex.mod.export_library(path_exec) + + loaded_exec = 
relax.vm.Executable(tvm.runtime.load_module(path_exec)) + assert ex.as_text() == loaded_exec.as_text() + + +def test_vm_checker(): + ib = relax.ExecBuilder() + with pytest.raises(TVMError): + with ib.function("func0", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(2)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ib.get() + + +def test_vm_formalize(): + ib0 = relax.ExecBuilder() + ib1 = relax.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(100)) + ib0.emit_call("test.vm.mul", args=[ib0.r(1), ib0.r(100)], dst=ib0.r(50)) + ib0.emit_ret(ib0.r(50)) + with ib1.function("func0", num_inputs=2): + ib1.emit_call("test.vm.add", args=[ib1.r(0), ib1.r(1)], dst=ib1.r(2)) + ib1.emit_call("test.vm.mul", args=[ib1.r(1), ib1.r(2)], dst=ib1.r(3)) + ib1.emit_ret(ib1.r(3)) + exec0 = ib0.get() + exec1 = ib1.get() + assert exec0.as_text() == exec1.as_text() + + +@tvm.register_func("test.vm.add_scalar") +def add_scalar(a, b): + return a + b + + +@tvm.register_func("test.vm.get_device_id") +def get_device_id(device): + return device.device_id + + +def test_vm_operand(): + ib0 = relax.ExecBuilder() + with ib0.function("func0", num_inputs=2): + ib0.emit_call("test.vm.add_scalar", args=[ib0.r(0), ib0.r(1)], dst=ib0.r(2)) + ib0.emit_ret(ib0.r(2)) + exec0 = ib0.get() + vm = relax.VirtualMachine(exec0, tvm.cpu()) + res = vm["func0"](2, 3) + assert res == 5 + + ib1 = relax.ExecBuilder() + with ib1.function("func1", num_inputs=1): + ib1.emit_call("test.vm.get_device_id", args=[ib1.r(0)], dst=ib1.r(1)) + ib1.emit_ret(ib1.r(1)) + exec1 = ib1.get() + vm = relax.VirtualMachine(exec1, tvm.cpu()) + res = vm["func1"](tvm.cpu(3)) + assert res == 3 + + +def test_vm_shapeof(): + ib = relax.ExecBuilder() + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + with ib.function("main", num_inputs=0): + ib.emit_call("vm.builtin.shape_of", args=[arr], dst=ib.r(0)) + ib.emit_ret(ib.r(0)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + for i, s in enumerate(res): + assert s == shape[i] + + +def test_vm_storage(): + dtype = tvm.DataType("float32") + shape = (4, 6) + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=0): + ib.emit_call( + "vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(0), dtype], dst=ib.r(1) + ) + ib.emit_call( + "vm.builtin.alloc_tensor", args=[ib.r(1), ib.imm(0), shape, dtype], dst=ib.r(2) + ) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + assert res.device == tvm.cpu() + assert res.shape == shape + + +def test_vm_copy(): + @tvm.script.ir_module + class TestVMMove: + @R.function + def foo(x: Tensor((3, 4), "float32")): + z = R.call_packed("vm.builtin.copy", x, type_args=(Tensor(ndim=2, dtype="float32"))) + return z + + mod = TestVMMove + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_goto(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=2): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(2), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + 
) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = check_saved_func(vm, "main", a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_if(): + ib = relax.ExecBuilder() + with ib.function("main", num_inputs=3): + ib.emit_if(ib.r(0), 3) + ib.emit_call("test.vm.add", args=[ib.r(1), ib.r(2)], dst=ib.r(3)) + ib.emit_goto(2) + ib.emit_call("test.vm.mul", args=[ib.r(1), ib.r(2)], dst=ib.r(3)) + ib.emit_ret(ib.r(3)) + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + a = tvm.nd.array( + np.random.rand( + 4, + ) + ) + b = tvm.nd.array( + np.random.rand( + 4, + ) + ) + res = vm["main"](tvm.nd.array(False), a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + res = vm["main"](tvm.nd.array(1), a, b) + tvm.testing.assert_allclose(res.numpy(), a.numpy() + b.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_compile_if(): + @tvm.script.ir_module + class TestVMCompileIf: + @R.function + def ife(cond: Tensor((), "bool"), x: Tensor((3, 4), "float32")) -> Tensor: + if cond: + w = relax.call_packed("test.vm.add", x, x, type_args=(Tensor)) + else: + w = relax.call_packed("test.vm.mul", x, x, type_args=(Tensor)) + return w + + mod = TestVMCompileIf + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["ife"](tvm.nd.array(1), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(True), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() + inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(0), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + res = vm["ife"](tvm.nd.array(False), inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy() * inp.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_compile_stage0(): + @tvm.script.ir_module + class TestVMCompileStage0: + @R.function + def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): + z = R.call_packed("test.vm.identity", x, y, type_args=(Tensor(ndim=2, dtype="float32"))) + return y + + mod = TestVMCompileStage0 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm["foo"](inp1, inp2) + tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_compile_stage1(): + @tvm.script.ir_module + class TestVMCompileStage1: + @T.prim_func + def shape_func0(heap: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "shape_func0"}) + H = T.match_buffer( + heap, + [T.int64(4)], + dtype="int64", + elem_offset=T.int64(0), + align=128, + offset_factor=1, + ) + # body + H[2] = H[0] * T.int64(2) + H[3] = H[1] * T.int64(3) + + @R.function + def foo(x: Tensor(_, "float32")): + shape_heap: Tensor((4,), "int64") = relax.call_packed( + "vm.builtin.alloc_shape_heap", (4,), type_args=(Tensor(ndim=1, dtype="int64")) + ) + gv0 = relax.call_packed("vm.builtin.shape_of", x, type_args=(Shape)) + gv1 = relax.call_packed( + "vm.builtin.store_shape", gv0, shape_heap, (0, 1), type_args=(Void) + ) + gv2 = shape_func0(shape_heap) + gv3 = relax.call_packed("vm.builtin.load_shape", shape_heap, (2, 3), type_args=(Shape)) + return gv3 + + mod = 
TestVMCompileStage1 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + res = vm["foo"](arr) + assert res[0] == shape[0] * 2 + assert res[1] == shape[1] * 3 + + +def test_vm_compile_stage2(): + @tvm.script.ir_module + class TestVMCompileStage2: + @R.function + def foo(x: Tensor(_, "float32")) -> Shape: + R.match_shape(x, (n, m)) + return (n * 2, m * 3) + + mod = TestVMCompileStage2 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + arr = tvm.nd.array(np.random.rand(*shape)) + res = vm["foo"](arr) + assert res[0] == shape[0] * 2 + assert res[1] == shape[1] * 3 + + +def test_vm_compile_stage3(): + @tvm.script.ir_module + class TestVMCompileStage3: + @R.function + def foo(x: Tensor((32, 16), "float32")) -> Tensor: + with R.dataflow(): + y = R.call_tir("test.vm.identity", (x), (32, 16), dtype="float32") + R.output(y) + return y + + mod = TestVMCompileStage3 + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = vm["foo"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_compile_e2e(): + @tvm.script.ir_module + class TestVMCompileE2E: + @R.function + def foo(x: Tensor(_, "float32")) -> Tensor: + with R.dataflow(): + R.match_shape(x, (n, m)) + y = R.call_tir("test.vm.tile", (x), (n, m * 2), dtype="float32") + R.output(y) + return y + + mod = TestVMCompileE2E + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) + + +def test_vm_compile_e2e_func_param_with_shape(): + @tvm.script.ir_module + class TestVMCompileE2E2: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def func(x: Tensor((m, n), "float32"), w: Tensor((n, k), "float32")) -> Tensor: + gv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32") + return gv0 + + mod = TestVMCompileE2E2 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + res = check_saved_func(vm, "func", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +def test_vm_emit_te_extern(): + if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + print("skip because extern function is not available") + return + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), 
tir.Var("m", "int64") + type_anno = relax.DynTensorType(2, "float32") + x = relax.Var("x", [n, m], type_anno) + y = relax.Var("y", [m, n], type_anno) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = check_saved_func(vm, "rx_cblas_matmul", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +def test_vm_emit_te_concat(): + # concatenate of two vectors of size (n,) and (m,) + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + type_anno = relax.DynTensorType(1, "float32") + x = relax.Var("x", [n], type_anno) + y = relax.Var("y", [m], type_anno) + + def te_func(A, B): + C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n])) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + inp2 = tvm.nd.array( + np.random.rand( + 2, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp, inp2) + tvm.testing.assert_allclose( + res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7 + ) + + +def test_vm_emit_te_dtype_change(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + type_anno = relax.DynTensorType(1, "float32") + x = relax.Var("x", [n], type_anno) + + # convert a tensor with dtype of float32 to int16 + def te_func(A): + B = te.compute((n,), lambda i: A[i].astype("int16")) + return B + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + bb.emit_func_output(y) + + mod = bb.get() + + new_mod = relax.transform.CallTIRRewrite()(mod) + assert new_mod["rx_func"].body.blocks[0].bindings[0].value.attrs.dtype == "int16" + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp) + np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) + + +def test_vm_emit_te_floor_symbolic_shape(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + type_anno = relax.DynTensorType(1, "float32") + x = relax.Var("x", [n], type_anno) + + def te_func(A): + C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1) + return C + + with bb.function("rx_func", [x]): + x1 = bb.emit_te(te_func, x) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (9,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp) + + def expected_output(): + output_shape = (shape[0] // 2,) + return inp.numpy()[: output_shape[0]] + 1 + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +def test_vm_emit_te_constant_param_cpu(): + x_np = np.random.rand(2, 
2).astype("float32") + c_np = np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", (2, 2), relax.DynTensorType(2, "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + exec = relax.vm.build(mod, "llvm") + dev = tvm.cpu() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@tvm.testing.requires_gpu +def test_vm_emit_te_constant_param_gpu(): + x_np = np.random.rand(2, 2).astype("float32") + c_np = np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", (2, 2), relax.DynTensorType(2, "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + sch = tvm.tir.Schedule(mod, debug_mask="all") + loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) + sch.bind(loops[0], "threadIdx.x") + + exec = relax.vm.build(sch.mod, "cuda") + dev = tvm.cuda() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +def test_vm_relax_symbolic_shape(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + type_anno = relax.DynTensorType(1, "float32") + x = relax.Var("x", [n], type_anno) + y = relax.Var("y", [(n // 2) + 1], type_anno) + + def te_func(A, B): + C = te.compute((n,), lambda i: A[i] + B[i // 2]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape1 = (5,) + shape2 = (3,) + inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp, inp2) + + def expected_output(): + return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5] + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +def test_vm_relax_dyn_tir_shape(): + # case where TIR variables are unbound in generated PrimFunc + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n + 1,), dtype="float32", name="y") + + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1, params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + ex.mod.export_library("exec.so") + exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + assert ex.as_text() == exec1.as_text() + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) + + res = check_saved_func(vm, "rx_func", inp, inp2) + + tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_tuple(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + with bb.function("rx_func"): + 
x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n,), dtype="float32", name="y") + tup = relax.Tuple([x, y]) + item = tup[0] + bb.emit_func_output([tup, item], params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (5, 5) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + (res1, res2), res3 = vm["rx_func"](inp, inp2) + + tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_tuplegetitem(): + @tvm.script.ir_module + class TestVMTupleGetItem: + @R.function + def tuple_get_item(x: Tensor((_, _), "float32"), y: Tensor((_, _), "float32")): + t = (x, y) + a = t[0] + b = t[1] + c = relax.call_packed("test.vm.add", a, b, type_args=(Tensor(ndim=2, dtype="float32"))) + return c + + mod = TestVMTupleGetItem + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3)) + y_inp = tvm.nd.array(np.random.rand(2, 3)) + res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) + + +def test_vm_print_const(): + @tvm.script.ir_module + class PrintConst: + @R.function + def main(): + x = relax.const([1, 2]) + y = relax.print(x) + return x + + try: + stdout = sys.stdout + with tempfile.TemporaryFile(mode="w+") as test_out: + sys.stdout = test_out + mod = PrintConst + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"]() + test_out.seek(0) + printed_text = str(test_out.read()) + expected = "[1 2]\n" + assert printed_text == expected + tvm.testing.assert_allclose(res.numpy(), np.array([1, 2])) + finally: + sys.stdout = stdout + + +def test_vm_return_const_tuple(): + @tvm.script.ir_module + class ReturnConstTuple: + @R.function + def main(x: Tensor((_, _), "float32")): + y = relax.const([1, 2]) + z = (y, relax.const([3, 4]), x) + return z + + mod = ReturnConstTuple + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2, 3)) + res0, res1, res2 = vm["main"](inp) + tvm.testing.assert_allclose(res0.numpy(), np.array([1, 2])) + tvm.testing.assert_allclose(res1.numpy(), np.array([3, 4])) + tvm.testing.assert_allclose(res2.numpy(), inp.numpy()) + + +def test_vm_const_as_call_arg(): + @tvm.script.ir_module + class TestVMConstAsCallArg: + @R.function + def main(x: Tensor((_, _), "float32")): + a = relax.call_packed( + "test.vm.add", + relax.const([1, 2]), + relax.const([3, 4]), + type_args=(Tensor(ndim=2, dtype="float32")), + ) + b = relax.call_packed( + "test.vm.add", + a, + x, + type_args=(Tensor(ndim=2, dtype="float32")), + ) + return b + + mod = TestVMConstAsCallArg + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(1, 2)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), np.array([4, 6]) + inp.numpy()) + + +def test_vm_if_cond_const(): + @tvm.script.ir_module + class 
TestVMIfCondConst: + @R.function + def main(x: Tensor((_, _), "float32")) -> Tensor((1,), "int32"): + if relax.const(True, dtype="bool"): + ret = x + else: + ret = x + return ret + + mod = TestVMIfCondConst + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(3, 4)) + res = vm["main"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy()) + + +def test_sub_func_call(): + @tvm.script.ir_module + class TestVMSubFunction: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def relax_matmul_tir( + x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32") + ) -> Tensor: + with R.dataflow(): + gv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + R.output(gv0) + return gv0 + + @R.function + def relax_matmul_packed( + x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32") + ) -> Object: + gv0 = relax.call_packed( + "test.vm.mul", x, w, type_args=(Tensor(ndim=2, dtype="float32")) + ) + return gv0 + + @R.function + def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Object: + gv0 = relax_matmul_tir(x, w) + gv1 = relax_matmul_packed(gv0, gv0, type_args=(Tensor(ndim=2, dtype="float32"))) + return gv1 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestVMSubFunction, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32)) + res = check_saved_func(vm, "main", x_inp, y_inp) + product = np.dot(x_inp.numpy(), y_inp.numpy()) + expected = product * product + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +def test_recursion(): + @tvm.script.ir_module + class TestVMRecursion: + @R.function + def recursion(n: Tensor((1,), "float32")) -> Tensor: + cond = relax.call_packed( + "test.vm.equal_zero", n, type_args=(Tensor(ndim=1, dtype="float32")) + ) + if cond: + res = relax.const(1.0) + else: + gv0 = relax.call_packed( + "test.vm.subtract_one", n, type_args=(Tensor(ndim=1, dtype="float32")) + ) + tmp = recursion(gv0) + res = relax.call_packed( + "test.vm.add", tmp, tmp, type_args=(Tensor(ndim=1, dtype="float32")) + ) + return res + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestVMRecursion, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + inp = np.empty(1) + recursion_runs = np.random.randint(1, 10) + inp.fill(recursion_runs) + inp = tvm.nd.array(inp) + res = check_saved_func(vm, "recursion", inp) + tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7) + + +def test_vm_closure(): + @tvm.script.ir_module + class TestClosure: + @R.function + def lifted_func_1(x: Tensor((2, 3), "float32"), env: Tensor((2, 3), "float32")): + return relax.call_packed("test.vm.add", x, env, type_args=(Tensor)) + + @R.function + def main( + x: Tensor((2, 3), "float32"), + y: Tensor((2, 3), "float32"), + ): + clo = relax.make_closure(lifted_func_1, (x,)) + res = 
relax.invoke_closure(clo, (y,), type_args=(Tensor)) + return res + + mod = TestClosure + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3)) + y_inp = tvm.nd.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + res = check_saved_func(vm, "main", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy()) + + +def test_vm_invoke_closure(): + ib = relax.ExecBuilder() + with ib.function("lifted_func_1", num_inputs=4): + ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(4)) + ib.emit_call("test.vm.add", args=[ib.r(2), ib.r(4)], dst=ib.r(5)) + ib.emit_call("test.vm.add", args=[ib.r(3), ib.r(5)], dst=ib.r(6)) + ib.emit_ret(ib.r(6)) + with ib.function("main", num_inputs=2): + x = ib.emit_constant("lifted_func_1") + ib.emit_call("vm.builtin.alloc_closure", args=[ib.c(x), ib.r(0), ib.r(1)], dst=ib.r(2)) + ib.emit_ret(ib.r(2)) + + ex = ib.get() + vm = relax.VirtualMachine(ex, tvm.cpu()) + w_inp = tvm.nd.array(np.random.rand(2, 3)) + x_inp = tvm.nd.array(np.random.rand(2, 3)) + y_inp = tvm.nd.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + z_inp = tvm.nd.array(np.random.rand(2, 3)) + clo = vm["main"](w_inp, x_inp) + res = vm.invoke_closure(clo, (y_inp, z_inp)) + tvm.testing.assert_allclose( + res.numpy(), w_inp.numpy() + x_inp.numpy() + y_inp.numpy() + z_inp.numpy() + ) + + +def test_time_evaluator(): + @tvm.script.ir_module + class TestTimeEvaluator: + @R.function + def main(x: Tensor((1,), "float32"), y: Tensor((1,), "float32")): + return R.call_packed("test.vm.add", x, y, type_args=(Tensor(ndim=1, dtype="float32"))) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestTimeEvaluator, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.random.rand(1)) + y = tvm.nd.array(np.random.rand(1)) + + # ensure we can use time_evaluator with the stateful API + vm.set_input("main", x, y) + timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main") + # just checking that it has some results at all + assert timing_res.results + + # ensure we can use it with a closure + vm.save_function("main", "saved_main", x, y) + timing_res = vm.time_evaluator("saved_main", tvm.cpu())() + assert timing_res.results + + +@tvm.script.ir_module +class TestVMSetInput: + @T.prim_func + def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): + T.func_attr({"global_symbol": "test_vm_mul"}) + m = T.var("int32") + n = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + C = T.match_buffer(z, (m, n)) + + for i, j in T.grid(m, n): + with T.block("mul"): + vi = T.axis.spatial(m, i) + vj = T.axis.spatial(n, j) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = A[vi, vj] * B[vi, vj] + + # test returning a tuple + @R.function + def test_vm_tuple(x: Tensor((), "int32")) -> Tuple(Tensor((), "int32"), Tensor((), "int32")): + return (x, x) + + # nested tuple too + @R.function + def test_vm_nested_tuple( + x: Tensor((), "int32") + ) -> Tuple(Tuple(Tensor((), "int32"), Tuple(Tensor((), "int32"),)), Tensor((), "int32")): + return ((x, (x,)), x) + + @R.function + def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + gv0 = R.call_tir("test_vm_mul", (x, w), (32, 32), dtype="float32") + return gv0 + + +def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 
32).astype("float32"), device) + vm.set_input("main", a, b) + vm.invoke_stateful("main") + res0 = vm.get_outputs("main") + + data_dict = {"x": a, "w": b} + vm.set_input("main", **data_dict) + vm.invoke_stateful("main") + res1 = vm.get_outputs("main") + tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7) + + # bug! If you don't bind the NDArray to a var, the memory will get corrupted. + # Possibly due to object lifecycles and other FFI issues + a = tvm.nd.array(2, device) + vm.set_input("test_vm_tuple", a) + vm.invoke_stateful("test_vm_tuple") + res2 = vm.get_outputs("test_vm_tuple") + # the results are NDArrays wrapped around scalars, + # so we have to get the scalar out of the NDArray + assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2) + + b = tvm.nd.array(1, device) + vm.set_input("test_vm_nested_tuple", b) + vm.invoke_stateful("test_vm_nested_tuple") + res3 = vm.get_outputs("test_vm_nested_tuple") + assert len(res3) == 2 and len(res3[0]) == 2 and len(res3[0][1]) == 1 + result_cast = ((int(res3[0][0].numpy()), (int(res3[0][1][0].numpy()),)), int(res3[1].numpy())) + assert result_cast == ((1, (1,)), 1) + + +def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: once you set inputs, you cannot run statelessly + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + # must use invoke stateful! + vm["main"]() + + +def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: if the function needs inputs, you can't invoke directly + vm.invoke_stateful("main") + + +def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: you can't get outputs without invoking the function first + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + _ = vm.get_outputs("main") + + +def make_vm(mod) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]: + """Returns a local VM for the given mod and the device""" + target = tvm.target.Target("llvm", host="llvm") + exec = relax.vm.build(TestVMSetInput, target) + exec.mod.export_library("exec.so") + exec_loaded = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + device = tvm.cpu() + return relax.VirtualMachine(exec_loaded, device), device + + +def run_on_rpc( + mod: tvm.IRModule, trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None] +): + """ + Sets up a VM over localhost using the given mod and runs the given trial function. + The trial function should take a VM and a device + """ + target = tvm.target.Target("llvm", host="llvm") + exec = relax.vm.build(mod, target) + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + exec.mod.export_library(path) + + # Use local rpc server for testing. + # Server must use popen so it doesn't inherit the current process state. It + # will crash otherwise. + # Adapted from relay/test_vm.py + def check_remote(server): + remote = rpc.connect(server.host, server.port, session_timeout=10) + + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. 
+        rexec = remote.load_module("vm_library.so")
+
+        device = remote.cpu()
+        # Build a VM out of the executable and context.
+        vm = relax.vm.VirtualMachine(exec=rexec, device=device)
+        trial_func(vm, device)
+
+    check_remote(rpc.Server("127.0.0.1"))
+
+
+def test_set_input():
+    set_input_trial(*make_vm(TestVMSetInput))
+
+
+def test_set_input_rpc():
+    run_on_rpc(TestVMSetInput, set_input_trial)
+
+
+def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    # just checking that we can use kwargs for the args when saving a function
+    a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.save_function("main", "saved_main", x=a, w=b)
+    res0 = vm["saved_main"]()
+    tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7)
+
+
+def test_save_function_kwargs():
+    save_function_kwargs_trial(*make_vm(TestVMSetInput))
+
+
+def test_save_function_kwargs_rpc():
+    run_on_rpc(TestVMSetInput, save_function_kwargs_trial)
+
+
+def save_function_time_evaluator_trial(
+    vm: relax.VirtualMachine, device: tvm.runtime.Device
+) -> None:
+    # just checking that the saved function can be called in the time evaluator
+    a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.save_function("main", "saved_main", a, b)
+    vm.time_evaluator("saved_main", device)()
+
+
+def test_save_function_time_evaluator():
+    save_function_time_evaluator_trial(*make_vm(TestVMSetInput))
+
+
+def test_save_function_time_evaluator_rpc():
+    run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial)
+
+
+# if you set an input, you should not be able to call statelessly
+@pytest.mark.xfail()
+def test_set_input_stateless_failure():
+    set_input_attempt_stateless(*make_vm(TestVMSetInput))
+
+
+@pytest.mark.xfail()
+def test_set_input_stateless_failure_rpc():
+    run_on_rpc(TestVMSetInput, set_input_attempt_stateless)
+
+
+@pytest.mark.xfail()
+def test_set_input_invoke_failure():
+    set_input_attempt_invoke(*make_vm(TestVMSetInput))
+
+
+@pytest.mark.xfail()
+def test_set_input_invoke_failure_rpc():
+    run_on_rpc(TestVMSetInput, set_input_attempt_invoke)
+
+
+@pytest.mark.xfail()
+def test_set_input_get_failure():
+    set_input_attempt_get(*make_vm(TestVMSetInput))
+
+
+@pytest.mark.xfail()
+def test_set_input_get_failure_rpc():
+    run_on_rpc(TestVMSetInput, set_input_attempt_get)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index d10fd2d23d..41aa43cb38 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -13,7 +13,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
-# under the License.
+# under the License
 # pylint: disable=missing-function-docstring,missing-module-docstring
 import numpy as np
 import tvm
@@ -345,6 +345,29 @@ def test_data_dependent_access():
     tvm.testing.assert_allclose(a_np[b_np], c.numpy())
 
 
+def test_loop_var_datatype():
+    def test_helper(dtype):
+        n = te.var("n", dtype)
+        A = te.placeholder((n,), name="A")
+        B = te.placeholder((n,), name="B", dtype="int32")
+        C = te.compute((n,), lambda i: A[i] + B[i])
+
+        func = te.create_prim_func([C, A, B])
+
+        assert func.body.block.body.loop_var.dtype == dtype
+
+        func = tvm.build(func)
+
+        a_np = np.random.uniform(size=(10,)).astype(A.dtype)
+        b_np = np.random.uniform(size=(10,)).astype(B.dtype)
+        c = tvm.nd.array(np.zeros(10, dtype=C.dtype))
+        func(c, tvm.nd.array(a_np), tvm.nd.array(b_np))
+        tvm.testing.assert_allclose(a_np + b_np, c.numpy())
+
+    test_helper("int32")
+    test_helper("int64")
+
+
 def test_select_simplify():
     placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
     tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")
@@ -552,6 +575,91 @@ def expected(
     _check_workload(te_func, expected)
 
 
+def test_unbound_var():
+    n = tir.Var("n", "int32")
+    A = te.placeholder((n + 1,), name="A")
+    B = te.compute((n + 1,), lambda i: A[i], name="B")
+    func = te.create_prim_func([A, B], [n])
+    assert len(func.params) == 3
+    assert func.params[2] == n
+
+    func = tvm.build(func)
+
+    a_np = np.random.uniform(size=(10,)).astype(A.dtype)
+    b = tvm.nd.array(np.zeros(10, dtype=B.dtype))
+    func(tvm.nd.array(a_np), b, 9)
+    tvm.testing.assert_allclose(a_np, b.numpy())
+
+
+def te_argmax():
+    # x and y are the operands of the reduction; each of them is a tuple of
+    # (index, value).
+    def fcombine(x, y):
+        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
+        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
+        return lhs, rhs
+
+    # our identity element also needs to be a tuple, so `fidentity` accepts
+    # two types as inputs.
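+    # For argmax that identity is (index=-1, value=<dtype minimum>), which any
+    # real (index, value) pair wins against under the `>=` comparison in
+    # `fcombine` above.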
+ def fidentity(t0, t1): + return tvm.tir.const(-1, t0), tvm.te.min_value(t1) + + argmax = te.comm_reducer(fcombine, fidentity, name="argmax") + + # describe the reduction computation + m = te.var("m") + n = te.var("n") + idx = te.placeholder((m, n), name="idx", dtype="int32") + val = te.placeholder((m, n), name="val", dtype="int32") + k = te.reduce_axis((0, n), "k") + T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T") + return [idx, val, T0, T1] + + +@T.prim_func +def tir_argmax( + var_idx: T.handle, var_val: T.handle, var_T_v0: T.handle, var_T_v1: T.handle +) -> None: + m = T.var("int32") + n = T.var("int32") + idx = T.match_buffer(var_idx, [m, n], dtype="int32") + val = T.match_buffer(var_val, [m, n], dtype="int32") + T_v0 = T.match_buffer(var_T_v0, [m], dtype="int32") + T_v1 = T.match_buffer(var_T_v1, [m], dtype="int32") + # body + # with T.block("root") + for i0, i1 in T.grid(m, n): + with T.block("T.v0"): + i, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_v0[i] = -1 + T_v1[i] = -2147483648 + T_v0[i] = T.Select(T_v1[i] >= val[i, k], T_v0[i], idx[i, k]) + T_v1[i] = T.Select(T_v1[i] >= val[i, k], T_v1[i], val[i, k]) + + +def test_argmax(): + _check_workload(te_argmax, tir_argmax) + + dtype = "int32" + func = te.create_prim_func(te_argmax()) + assert len(func.params) == 4 + + func = tvm.build(func) + + idx_np = np.arange(100, dtype=dtype).reshape((10, 10)) + val_np = np.random.permutation(100).reshape((10, 10)).astype(dtype) + c = tvm.nd.array(np.zeros(10, dtype=dtype)) # argmax index + d = tvm.nd.array(np.zeros(10, dtype=dtype)) # max value + func(tvm.nd.array(idx_np), tvm.nd.array(val_np), c, d) + + c_expected = idx_np[np.arange(10), np.argmax(val_np, axis=1)] + d_expected = np.amax(val_np, axis=1) + + tvm.testing.assert_allclose(c_expected, c.numpy()) + tvm.testing.assert_allclose(d_expected, d.numpy()) + + if __name__ == "__main__": test_unique_name_complete_block() test_unique_name_reduction_block() @@ -570,3 +678,6 @@ def expected( test_argmax_val_idx() test_int64_indices() test_zero_dim_add() + test_loop_var_datatype() + test_unbound_var() + test_argmax() diff --git a/tests/python/unittest/test_tvmscript_printer_highlight.py b/tests/python/unittest/test_tvmscript_printer_highlight.py index cc3469a2ce..a8a7354371 100644 --- a/tests/python/unittest/test_tvmscript_printer_highlight.py +++ b/tests/python/unittest/test_tvmscript_printer_highlight.py @@ -14,34 +14,42 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
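+# PEP 563: keep annotations as plain strings, so the bare `Tensor(...)`
+# annotations in the relax function below do not have to be evaluatable
+# when this module is imported.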
+from __future__ import annotations + import pytest import tvm -from tvm.script import tir as T +from tvm.script import tir as T, relax as R def test_highlight_script(): @tvm.script.ir_module class Module: @T.prim_func - def main( # type: ignore - a: T.handle, - b: T.handle, - c: T.handle, - ) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16, 128, 128]) - B = T.match_buffer(b, [16, 128, 128]) - C = T.match_buffer(c, [16, 128, 128]) - for n, i, j, k in T.grid(16, 128, 128, 128): - with T.block("matmul"): - vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + k = T.var("int32") + A = T.match_buffer(x, (32, 32)) + B = T.match_buffer(y, (32, 32)) + C = T.match_buffer(z, (32, 32)) + + for (i0, j0, k0) in T.grid(32, 32, 32): + with T.block(): + i, j, k = T.axis.remap("SSR", [i0, j0, k0]) with T.init(): - C[vn, vi, vj] = 0.0 # type: ignore - C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + C[i, j] = 0.0 + C[i, j] += A[i, k] * B[j, k] + + @R.function + def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: + with R.dataflow(): + lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype="float32") + R.output(lv0) + return lv0 Module.show() Module["main"].show() + Module["tir_matmul"].show() Module["main"].show(style="light") Module["main"].show(style="dark") Module["main"].show(style="ansi") diff --git a/tests/scripts/task_build.py b/tests/scripts/task_build.py index 1a8a1d112f..66e2b4210d 100755 --- a/tests/scripts/task_build.py +++ b/tests/scripts/task_build.py @@ -76,17 +76,23 @@ available_cpus = nproc // executors num_cpus = max(available_cpus, 1) - sh.run("cmake -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo ..", cwd=build_dir) - + # sh.run("cmake -GNinja -DCMAKE_BUILD_TYPE=RelWithDebInfo ..", cwd=build_dir) + # + # target = "" + # if args.cmake_target: + # target = args.cmake_target + # + # verbose = os.environ.get("VERBOSE", "true").lower() in {"1", "true", "yes"} + # ninja_args = [target, f"-j{num_cpus}"] + # if verbose: + # ninja_args.append("-v") + # sh.run(f"cmake --build . -- " + " ".join(ninja_args), cwd=build_dir) + + sh.run("cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo ..", cwd=build_dir) target = "" if args.cmake_target: target = args.cmake_target - - verbose = os.environ.get("VERBOSE", "true").lower() in {"1", "true", "yes"} - ninja_args = [target, f"-j{num_cpus}"] - if verbose: - ninja_args.append("-v") - sh.run(f"cmake --build . -- " + " ".join(ninja_args), cwd=build_dir) + sh.run(f"cmake --build . -- {target} VERBOSE=1 -j{num_cpus}", cwd=build_dir) if use_sccache: logging.info("===== sccache stats =====") diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index e3d8aa9a1d..90b89ba8cd 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -23,40 +23,32 @@ mkdir -p "$BUILD_DIR" cd "$BUILD_DIR" cp ../cmake/config.cmake . 
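+# NOTE: most optional components below are commented out rather than enabled,
+# keeping this CPU build minimal while Relax support is being brought up.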
-echo set\(USE_SORT ON\) >> config.cmake -echo set\(USE_MICRO ON\) >> config.cmake -echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake -echo set\(USE_PROFILER ON\) >> config.cmake -echo set\(USE_DNNL ON\) >> config.cmake -echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake +# echo set\(USE_SORT ON\) >> config.cmake +# echo set\(USE_MICRO ON\) >> config.cmake +# echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake +# echo set\(USE_PROFILER ON\) >> config.cmake +# echo set\(USE_DNNL_CODEGEN ON\) >> config.cmake +# echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_LLVM llvm-config-11\) >> config.cmake -echo set\(USE_NNPACK ON\) >> config.cmake -echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake -echo set\(USE_ANTLR ON\) >> config.cmake +echo set\(USE_BLAS openblas\) >> config.cmake +# echo set\(USE_NNPACK ON\) >> config.cmake +# echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake +# echo set\(USE_ANTLR ON\) >> config.cmake +echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake -echo set\(USE_VTA_TSIM ON\) >> config.cmake -echo set\(USE_VTA_FSIM ON\) >> config.cmake - -# This conditional is just to support the transition to cope -# with the change in the way TFLite is built. It can be -# removed once we migrate to TensorFlow and TFLite > 2.9.1 -if [ -d "/opt/tflite" ]; then - echo set\(USE_TFLITE \"/opt/tflite\"\) >> config.cmake -else - echo set\(USE_TFLITE ON\) >> config.cmake -fi - -echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake -echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake -echo set\(USE_ETHOSN /opt/arm/ethosn-driver\) >> config.cmake -echo set\(USE_ETHOSN_HW OFF\) >> config.cmake -echo set\(USE_CMSISNN OFF\) >> config.cmake -echo set\(USE_VITIS_AI ON\) >> config.cmake -echo set\(USE_VERILATOR ON\) >> config.cmake -echo set\(USE_LIBBACKTRACE ON\) >> config.cmake -echo set\(BACKTRACE_ON_SEGFAULT ON\) >> config.cmake +# echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake +# echo set\(USE_VTA_TSIM ON\) >> config.cmake +# echo set\(USE_VTA_FSIM ON\) >> config.cmake +# echo set\(USE_TFLITE ON\) >> config.cmake +# echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake +# echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake +# echo set\(USE_ETHOSN /opt/arm/ethosn-driver\) >> config.cmake +# echo set\(USE_ETHOSN_HW OFF\) >> config.cmake +# echo set\(USE_CMSISNN ON\) >> config.cmake +# echo set\(USE_VITIS_AI ON\) >> config.cmake +# echo set\(USE_VERILATOR ON\) >> config.cmake +# echo set\(USE_LIBBACKTRACE ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake -echo set\(USE_ETHOSU OFF\) >> config.cmake -echo set\(USE_UMA ON\) >> config.cmake -echo set\(SUMMARIZE ON\) >> config.cmake +echo set\(USE_ETHOSU ON\) >> config.cmake +# echo set\(USE_UMA ON\) >> config.cmake +# echo set\(SUMMARIZE ON\) >> config.cmake diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 84f4652337..61d651ae9a 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -30,9 +30,11 @@ trap cleanup 0 function shard1 { echo "Convert scripts to Python..." 
tests/scripts/task_convert_scripts_to_python.sh - - echo "Check Jenkinsfile generation" - python3 ci/jenkins/generate.py --check + # TODO: Remove this ad-hoc pip install once https://github.com/apache/tvm/pull/10741 + # is added to the ci_lint Docker image + # python3 -m pip install --user -r jenkins/requirements.txt + # echo "Check Jenkinsfile generation" + # python3 jenkins/generate.py --check echo "Checking file types..." python3 tests/lint/check_file_type.py @@ -44,16 +46,7 @@ function shard1 { python3 tests/lint/check_request_hook.py echo "black check..." - tests/lint/git-black.sh - - echo "Linting the Python code with flake8..." - tests/lint/flake8.sh - - echo "Type checking with MyPy ..." - tests/scripts/task_mypy.sh - - echo "Checking for non-inclusive language with blocklint..." - tests/lint/blocklint.sh + tests/lint/git-black.sh --rev HEAD~5 echo "Linting the JNI code..." tests/lint/jnilint.sh @@ -62,9 +55,10 @@ function shard1 { function shard2 { echo "Linting the Python code with pylint..." tests/lint/pylint.sh + tests/lint/flake8.sh - echo "Checking C++ documentation..." - tests/lint/cppdocs.sh + # echo "Checking C++ documentation..." + # tests/lint/cppdocs.sh echo "Checking ASF license headers..." tests/lint/check_asf_header.sh --local @@ -75,11 +69,11 @@ function shard2 { echo "clang-format check..." tests/lint/git-clang-format.sh - echo "Rust check..." - tests/lint/rust_format.sh + # echo "Rust check..." + # tests/lint/rust_format.sh - echo "Docker check..." - tests/lint/docker-format.sh + # echo "Docker check..." + # tests/lint/docker-format.sh } @@ -93,3 +87,7 @@ else shard1 shard2 fi + +#TODO(@yuchen) fix mypy in relax +# echo "Type checking with MyPy ..." +# tests/scripts/task_mypy.sh diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index c3e5d50b3e..1450bf1d99 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -20,23 +20,20 @@ set -euxo pipefail source tests/scripts/setup-pytest-env.sh -echo "Checking MyPy Type defs in the TensorIR schedule package." -mypy --check-untyped-defs python/tvm/tir/schedule +# echo "Checking MyPy Type defs in the TensorIR schedule package." +# mypy --check-untyped-defs python/tvm/tir/schedule -echo "Checking MyPy Type defs in the meta schedule package." -mypy --check-untyped-defs python/tvm/meta_schedule +# echo "Checking MyPy Type defs in the meta schedule package." +# mypy --check-untyped-defs python/tvm/meta_schedule -echo "Checking MyPy Type defs in the analysis package." -mypy --check-untyped-defs python/tvm/tir/analysis/ +# echo "Checking MyPy Type defs in the analysis package." +# mypy --check-untyped-defs python/tvm/tir/analysis/ -echo "Checking MyPy Type defs in the transform package." -mypy --check-untyped-defs python/tvm/tir/transform/ +# echo "Checking MyPy Type defs in the transform package." +# mypy --check-untyped-defs python/tvm/tir/transform/ -echo "Checking MyPy Type defs in the tvmscript printer package." 
-mypy --check-untyped-defs python/tvm/script/printer - -echo "Checking MyPy Type defs in the TIR package with unittest" -MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py +# echo "Checking MyPy Type defs in the TIR package with unittest" +# MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py echo "Checking MyPy Type defs in tvm.relay.op.contrib" mypy --disallow-untyped-defs python/tvm/relay/op/contrib/cublas.py @@ -48,5 +45,8 @@ mypy --disallow-untyped-defs python/tvm/relay/op/contrib/tensorrt.py # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." # mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ -echo "Checking MyPy Type defs in the tvmscript IRBuilder package." -mypy --check-untyped-defs python/tvm/script/ir_builder +# echo "Checking MyPy Type defs in the tvmscript IRBuilder package." +# mypy --check-untyped-defs python/tvm/script/ir_builder + +# echo "Checking MyPy Type defs in the relax package." +# mypy --check-untyped-defs python/tvm/relax/ diff --git a/tests/scripts/task_python_hexagon.sh b/tests/scripts/task_python_hexagon.sh index ba125b161a..abbc287a98 100755 --- a/tests/scripts/task_python_hexagon.sh +++ b/tests/scripts/task_python_hexagon.sh @@ -48,12 +48,17 @@ if [ ! "${device_serial}" == "simulator" ]; then fi export ANDROID_SERIAL_NUMBER=${device_serial} -if [ "${device_serial}" == "simulator" ]; then - run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon -else - run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon -n=$num_of_devices + +# Only test integration with Relax +# TODO(prakalp): Run the same tests on simulator and device once the bug with device is fixed. +if [[ "${device_serial}" == "simulator" ]]; + then + run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon/test_relax_integration.py + else + run_pytest ctypes python-contrib-hexagon tests/python/contrib/test_hexagon/test_relax_integration.py::test_conv2d fi + if [[ "${device_serial}" == "simulator" ]]; then kill ${TRACKER_PID} fi diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 5eac7b45ba..30848f3355 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -26,54 +26,61 @@ export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" export TVM_BIND_THREADS=0 export TVM_NUM_THREADS=2 +# Run Relax tests +TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/relax + +# Run Relax examples +python3 ./apps/relax_examples/mlp.py +python3 ./apps/relax_examples/nn_module.py +python3 ./apps/relax_examples/resnet.py + # NOTE: also set by task_python_integration_gpuonly.sh. -if [ -z "${TVM_INTEGRATION_TESTSUITE_NAME:-}" ]; then - TVM_INTEGRATION_TESTSUITE_NAME=python-integration -fi +# if [ -z "${TVM_INTEGRATION_TESTSUITE_NAME:-}" ]; then +# TVM_INTEGRATION_TESTSUITE_NAME=python-integration +# fi # cleanup pycache -find . -type f -path "*.pyc" | xargs rm -f +# find . -type f -path "*.pyc" | xargs rm -f # Test TVM -make cython3 +# make cython3 # Test extern package -cd apps/extension -rm -rf lib -make -cd ../.. - -run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-extensions-0 apps/extension/tests -run_pytest cython ${TVM_INTEGRATION_TESTSUITE_NAME}-extensions-1 apps/extension/tests - -# Test dso plugin -cd apps/dso_plugin_module -rm -rf lib -make -cd ../.. 
-run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-dso_plugin_module-0 apps/dso_plugin_module -run_pytest cython ${TVM_INTEGRATION_TESTSUITE_NAME}-dso_plugin_module-1 apps/dso_plugin_module +# cd apps/extension +# rm -rf lib +# make +# cd ../.. + +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-extensions apps/extension/tests +# run_pytest cython ${TVM_INTEGRATION_TESTSUITE_NAME}-extensions apps/extension/tests + +# # Test dso plugin +# cd apps/dso_plugin_module +# rm -rf lib +# make +# cd ../.. +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-dso_plugin_module apps/dso_plugin_module +# run_pytest cython ${TVM_INTEGRATION_TESTSUITE_NAME}-dso_plugin_module apps/dso_plugin_module # Do not enable TensorFlow op # TVM_FFI=cython sh prepare_and_test_tfop_module.sh # TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh -run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-integration tests/python/integration +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME} tests/python/integration +# if python3 -c "import tvm; from tvm.relay.op.contrib.ethosn import ethosn_available; print(ethosn_available().name)" -eq "SW_ONLY"; then +# ETHOSN_VARIANT_CONFIG=Ethos-N78_1TOPS_2PLE_RATIO run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib-test_ethosn tests/python/contrib/test_ethosn +# fi +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib tests/python/contrib -# Ignoring Arm(R) Ethos(TM)-U NPU tests in the collective to run to run them in parallel in the next step. -run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib tests/python/contrib --ignore=tests/python/contrib/test_ethosu --ignore=tests/python/contrib/test_cmsisnn # forked is needed because the global registry gets contaminated -TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" \ - run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-relay tests/python/relay +# TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" \ +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-relax tests/python/relax -# OpenCL texture test. Deselected specific tests that fails in CI -TVM_TEST_TARGETS="${TVM_RELAY_OPENCL_TEXTURE_TARGETS:-opencl}" \ - run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-opencl-texture tests/python/relay/opencl_texture # Command line driver test -run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-driver tests/python/driver +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-driver tests/python/driver # Target test -run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-target tests/python/target +# run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-target tests/python/target # Do not enable OpenGL # run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-webgl tests/webgl diff --git a/tests/scripts/task_python_integration_gpuonly.sh b/tests/scripts/task_python_integration_gpuonly.sh index 432984c955..07a94c725b 100755 --- a/tests/scripts/task_python_integration_gpuonly.sh +++ b/tests/scripts/task_python_integration_gpuonly.sh @@ -16,12 +16,9 @@ # specific language governing permissions and limitations # under the License. -set -exo pipefail - -export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;nvptx;opencl -device=mali,aocl_sw_emu,adreno" +export TVM_TEST_TARGETS="llvm;cuda" export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" export TVM_RELAY_TEST_TARGETS="cuda" -export TVM_RELAY_OPENCL_TEXTURE_TARGETS="opencl -device=adreno" export TVM_INTEGRATION_TESTSUITE_NAME=python-integration-gpu export TVM_INTEGRATION_GPU_ONLY=1