[WIP][Unity][TVMScript] return from IfNode #14176

Closed · wants to merge 81 commits
9745433
[Unity] Relax VM (#13878)
YuchenJin Feb 1, 2023
2f96da7
[Unity] Relax expressions and types (#13901)
YuchenJin Feb 2, 2023
ecbd0a4
[Unity][IR] First-class StructInfo (#13907)
YuchenJin Feb 3, 2023
76cc9f7
[Unity][CI] Unity specific jenkins setup (do not upstream to main) (#…
tqchen Feb 3, 2023
fa561c8
[Unity] Basic StructInfo Analysis and Expr construction (#13916)
YuchenJin Feb 5, 2023
ff488e9
[Unity] Relax BlockBuilder and ExprMutator (#13926)
YuchenJin Feb 7, 2023
450d2a7
[Unity] Relax TVMScript Parser. (#13932)
Hzfengsy Feb 8, 2023
409bf91
[Unity] Relax TVMScript Printer (#13944)
junrushao Feb 10, 2023
5095764
[Unity] Relax VM codegen (#13954)
YuchenJin Feb 11, 2023
55c2d1f
[Unity] Relax VM shape lowering pass (#13956)
YuchenJin Feb 11, 2023
9e47ae6
[Unity] e2e Relax minimum build flow (#13961)
YuchenJin Feb 11, 2023
819c720
[Unity][TVMScript] Use explicit `R.shape` in TVMScript (#13979)
Hzfengsy Feb 14, 2023
e48d4d2
[Unity] Relax op: index (#13987)
MasterJH5574 Feb 14, 2023
886689a
[Unity] Relax op: datatype (#13986)
MasterJH5574 Feb 14, 2023
20ca7c0
[Unity] Relax op: set (#13990)
MasterJH5574 Feb 14, 2023
c06d16f
[Unity] Relax op: image (#13994)
MasterJH5574 Feb 14, 2023
27dde56
[Unity] Relax op: arithmetic, comparison (#13983)
MasterJH5574 Feb 14, 2023
d4a7cfc
[Unity] Relax op: statistical (#13991)
MasterJH5574 Feb 14, 2023
fcf4f59
[Unity] Relax op: neural networks (#13993)
MasterJH5574 Feb 14, 2023
35f17cf
[Unity] Relax op: creation (#13984)
MasterJH5574 Feb 14, 2023
b95a20a
[Unity] Relax op: linear algebra (#13988)
MasterJH5574 Feb 14, 2023
4577c98
[Unity] Relax op: search (#13992)
MasterJH5574 Feb 14, 2023
2470435
[Unity] Relax op: manipulation (#13989)
MasterJH5574 Feb 14, 2023
fe81dda
[Unity] NestedMsg Support utility (#13995)
tqchen Feb 14, 2023
75b9057
[Unity][Pass] Operator Fusion Passes (#14001)
Hzfengsy Feb 15, 2023
6475d98
[Unity][Pass] LambdaLift pass (#14012)
yongwww Feb 16, 2023
af63d19
[Unity][VM] Supporting "compiled" exec mode. (#14015)
tqchen Feb 17, 2023
b06d779
[Unity][Pass] BindParams pass, FoldConstant pass (#14016)
sunggg Feb 17, 2023
2aed169
[Unity][Pass][TuningAPI] Introduce TuningAPI and MetaSchedule pass (#…
sunggg Feb 17, 2023
eeb40ac
[Unity] Relay -> Relax translator (#14026)
YuchenJin Feb 17, 2023
47722e3
[Unity][Pass] Normalize Pass (#14031)
LeshengJin Feb 18, 2023
bccae02
[Unity][BlockBuilder] CallTE convert PrimValue args (#14028)
MasterJH5574 Feb 18, 2023
2d9fcfa
[Unity][Pass] Wellformed Analysis (#14032)
LeshengJin Feb 18, 2023
8d05dce
[Unity][TVMScript] Move tir/relax import in script out of __init__.py…
MasterJH5574 Feb 18, 2023
7ccda25
[Unity][Pass] Operator legalization (#14029)
MasterJH5574 Feb 18, 2023
b2e46d0
[Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035)
Ubospica Feb 18, 2023
f8ad784
[Unity] Initial PyTorch Frontend (#14037)
MasterJH5574 Feb 18, 2023
df0e043
[Unity][Pass] Block-level static memory planning (#14038)
MasterJH5574 Feb 18, 2023
ff84737
[Unity] Disallow inline prim_func in relax IR (#14040)
yongwww Feb 18, 2023
7d2296f
[Unity] Update tests to adapt to latest TVMScript syntax (#14039)
MasterJH5574 Feb 18, 2023
ef3524a
[Unity] Relax dataflow pattern language (matching) (#14041)
ganler Feb 18, 2023
988b2aa
[Unity] Statement rewriter for DataflowBlock (#14043)
ganler Feb 19, 2023
6316644
[Unity][Pass] FuseOps FuseTIR fixes (#14044)
MasterJH5574 Feb 19, 2023
166bb92
[Unity][TVMScript] Overload `__neg__` for relax expr (#14045)
SiriusNEO Feb 19, 2023
6f4ca6b
[Unity][VM] Add per-op profiling support (#14053)
masahi Feb 20, 2023
be1cc69
[Unity][BYOC] Add pattern-based partitioning pass (#14054)
masahi Feb 20, 2023
6d5f6f0
[Unity] Relax op: collapse sum (#14059)
SiriusNEO Feb 21, 2023
93cf087
[Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058)
MasterJH5574 Feb 21, 2023
f514905
[Unity][Pass] Remove Unused Function (#14061)
sunggg Feb 21, 2023
8083332
[Unity][BYOC] Add pass to merge composite functions to offload large …
masahi Feb 21, 2023
c575220
[Unity][Frontend] Annotate number of non-static input of FX function …
vinx13 Feb 21, 2023
9be900b
[Unity][Transform] Add LiftTransformParams pass (#14069)
vinx13 Feb 21, 2023
5eee3af
[Unity][BYOC][Pass] RunCodegen and TensorRT (#14078)
sunggg Feb 22, 2023
69cf869
[Unity][Pass] Canonicalize Bindings (#14079)
YuchenJin Feb 22, 2023
a40f1da
[Unity] Add testcases for `expr_args_converter` (#14080)
Hzfengsy Feb 22, 2023
59692e7
[Unity][BYOC] Add CUTLASS backend (#14081)
masahi Feb 22, 2023
cdd61cd
[Unity][BYOC] Add DNNL backend (#14082)
masahi Feb 22, 2023
e7354e6
[Unity][Op] `log_softmax` and `cross_entropy_with_logits` (#14083)
SiriusNEO Feb 22, 2023
df67561
[Unity][Analysis] TIR pattern kind analysis for multi-buffer write bl…
MasterJH5574 Feb 22, 2023
c0a591d
[Unity][Fix][Pass] FoldConstant with DCE in dataflow block (#14087)
MasterJH5574 Feb 22, 2023
a283a71
[Unity] Refactor Relax Build JIT UX (#14088)
tqchen Feb 22, 2023
d1997fd
[Unity][Relax] Set Shape Function to Be Host Function (#14090)
zxybazh Feb 22, 2023
4ca7107
[Unity] Fix typo in the comment (#14096)
vinx13 Feb 22, 2023
fc5981b
[Unity] Lower `shape_of` to a builtin (#14093)
YuchenJin Feb 22, 2023
3f4835c
[Unity] Relax Recursive function (#14092)
yongwww Feb 23, 2023
4d72daf
[Unity][Layout] Add layout transformation analysis for PrimFunc (#14066)
psrivas2 Feb 23, 2023
3f12d4d
[Unity] Remove attributes of relax.print, assert and unique (#14101)
yongwww Feb 23, 2023
d3a0e98
[Unity][BYOC]Add relax backend pattern registry (#14106)
yelite Feb 24, 2023
cc5292c
[Unity] Update tests again to adapt to latest TVMScript syntax (#14115)
Ubospica Feb 24, 2023
cfce06f
[Unity][Fix] Fix bug in MergeCompositeFunctions (#14117)
Ubospica Feb 24, 2023
82578c3
[Unity][BlockBuilder] Add `name_hint` argument for `emit` and `emit_o…
SiriusNEO Feb 25, 2023
678d01d
[Unity][WEB] Relax vm on web runtime (#14131)
tqchen Feb 25, 2023
e62169c
[Unity] Add Global info (#14132)
jinhongyii Feb 26, 2023
d7a6285
[Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#1…
yelite Feb 27, 2023
ff21d66
[Unity][TVMScript] emit_te sugar (#14123)
yongwww Feb 27, 2023
15ba19f
[Unity][BYOC] Assign group to unused bindings and ignroe PrimFunc (#1…
vinx13 Feb 27, 2023
2d0c2e9
[Unity][TVMScript] Multiple return support in Relax
yongwww Feb 24, 2023
752c253
Use ReturnGlobalInfo
yongwww Feb 28, 2023
65b53e8
update printer
yongwww Mar 1, 2023
e32d773
Remove null_expr
yongwww Mar 1, 2023
bfe9082
update
yongwww Mar 2, 2023
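The commit titles above describe allowing `return` inside the branches of an `if` in Relax TVMScript. This is not the PR's actual implementation (which operates on Relax `If` nodes and `SeqExpr`s); as a language-agnostic sketch of the same normalization idea — merging a per-branch `return` pair into a single function output — here is a minimal rewrite over plain Python source using the standard `ast` module:

```python
import ast


def normalize_branch_returns(src: str) -> str:
    """Rewrite a trailing `if c: return a / else: return b` inside a
    function into a single conditional return, mimicking how a script
    parser can merge per-branch returns into one function output."""
    tree = ast.parse(src)
    fn = tree.body[0]
    assert isinstance(fn, ast.FunctionDef)
    last = fn.body[-1]
    if (
        isinstance(last, ast.If)
        and len(last.body) == 1 and isinstance(last.body[0], ast.Return)
        and len(last.orelse) == 1 and isinstance(last.orelse[0], ast.Return)
    ):
        # replace the two-branch return with one merged return expression
        fn.body[-1] = ast.Return(
            value=ast.IfExp(
                test=last.test,
                body=last.body[0].value,
                orelse=last.orelse[0].value,
            )
        )
    return ast.unparse(ast.fix_missing_locations(tree))


example = """
def f(cond, a, b):
    if cond:
        return a
    else:
        return b
"""
print(normalize_branch_returns(example))
# the `if` statement collapses into: return a if cond else b
```

The real pass has to do more work than this sketch (scoping, struct info, branches that fall through), but the shape of the transformation is the same: both branch results flow into one merge point.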
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -289,6 +289,14 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/driver/*.cc
src/support/*.cc
src/script/*.cc
src/relax/ir/*.cc
src/relax/op/*.cc
src/relax/analysis/*.cc
src/relax/transform/*.cc
src/relax/backend/vm/*.cc
src/relax/backend/task_extraction.cc
src/relax/backend/pattern_registry.cc
src/relax/utils.cc
)

tvm_file_glob(GLOB CODEGEN_SRCS
@@ -335,6 +343,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
)

if(BUILD_FOR_HEXAGON)
253 changes: 253 additions & 0 deletions apps/relax_examples/e2e_auto_tir.py
@@ -0,0 +1,253 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import os
import csv
import json
import argparse
import logging
from typing import Dict
import numpy as np # type: ignore

import tvm
from tvm import relay, relax, runtime, transform
from tvm.ir.module import IRModule
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.relax.testing import relay_translator
from tvm.target.target import Target


def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument(
        "--workload",
        type=str,
        required=True,
    )
    args.add_argument(
        "--input-shape",
        type=str,
        required=True,
    )
    args.add_argument(
        "--target",
        type=str,
        required=True,
    )
    args.add_argument(
        "--num-trials",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-host",
        type=str,
        default=None,
    )
    args.add_argument(
        "--rpc-port",
        type=int,
        default=None,
    )
    args.add_argument(
        "--rpc-key",
        type=str,
        default=None,
    )
    args.add_argument(
        "--work-dir",
        type=str,
        required=True,
    )
    args.add_argument(
        "--cache-dir",
        type=str,
        default=None,
    )
    args.add_argument(
        "--rpc-timeout-sec",
        type=int,
        default=180,
    )
    args.add_argument("--num-measurement-repeats", type=int, default=5)
    args.add_argument("--num-measurements", type=int, default=10)
    args.add_argument("--results-file", type=str, required=False, default=None)
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
    parsed.input_shape = json.loads(parsed.input_shape)
    if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu":
        parsed.alloc_repeat = 3
    else:
        parsed.alloc_repeat = 1
    if parsed.rpc_host and parsed.rpc_port and parsed.rpc_key:
        parsed.rpc_config = ms.runner.RPCConfig(
            tracker_host=parsed.rpc_host,
            tracker_port=parsed.rpc_port,
            tracker_key=parsed.rpc_key,
            session_timeout_sec=parsed.rpc_timeout_sec,
        )
        parsed.workers = parsed.rpc_config.count_num_servers(allow_missing=False)
    else:
        # check all rpc configs are None
        assert (
            (parsed.rpc_host is None) and (parsed.rpc_port is None) and (parsed.rpc_key is None)
        ), "Please set all 'rpc_host', 'rpc_port' and 'rpc_key' to use RPC server"
        parsed.rpc_config = None
        parsed.workers = 1
    return parsed


logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
ARGS = _parse_args()


def apply_opt_before_tuning(
    relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target
):
    with transform.PassContext(opt_level=3):
        main_func = relay_mod["main"]
        bind_main_func = relay.build_module.bind_params_by_name(main_func, params)
        relay_mod = IRModule.from_expr(bind_main_func)
        relay_mod = relay.transform.SimplifyInference()(relay_mod)
        relay_mod = relay.transform.FoldConstant()(relay_mod)
        relay_mod = relay.transform.FoldScaleAxis()(relay_mod)
        relay_mod = relay.transform.CanonicalizeOps()(relay_mod)
        relay_mod = relay.transform.AlterOpLayout()(relay_mod)
        relay_mod = relay.transform.FoldConstant()(relay_mod)

        relax_mod = relay_translator.from_relay(relay_mod["main"], target=target)
        relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
        relax_mod = relax.transform.FuseOps()(relax_mod)
        relax_mod = relax.transform.FuseTIR()(relax_mod)
    return relax_mod


def f_measurement(
    rt_mod: runtime.Module, device: runtime.ndarray.Device, input_data: Dict[str, runtime.NDArray]
):
    vm = relax.VirtualMachine(rt_mod, device=device)
    vm.save_function("main", "measure_func", **input_data, include_return=False)
    evaluator = vm.time_evaluator(
        func_name="measure_func",
        dev=device,
        repeat=ARGS.num_measurement_repeats,
        number=ARGS.num_measurements,
        min_repeat_ms=500,
    )
    return evaluator()


def get_runner():
    runner_config = {
        "evaluator_config": ms.runner.EvaluatorConfig(
            number=3,
            repeat=1,
            min_repeat_ms=100,
            enable_cpu_cache_flush=False,
        ),
        "alloc_repeat": ARGS.alloc_repeat,
    }
    if ARGS.rpc_config:
        runner = ms.runner.RPCRunner(
            rpc_config=ARGS.rpc_config, max_workers=ARGS.workers, **runner_config
        )
    else:
        runner = ms.runner.LocalRunner(**runner_config)

    return runner


def main():
    relay_mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    input_info = {input_name: input_shape}
    input_data = {}
    for input_name, input_shape in input_info.items():
        print(f"  input_name: {input_name}")
        print(f"  input_shape: {input_shape}")
        print(f"  input_dtype: {input_dtype}")

    # translate the ResNet model from Relay to Relax
    relax_mod = apply_opt_before_tuning(relay_mod, params, target=ARGS.target)
    assert isinstance(relax_mod, tvm.IRModule)

    db = ms.relax_integration.tune_relax(
        mod=relax_mod,
        target=ARGS.target,
        params=params,
        num_trials_per_iter=64,
        max_trials_per_task=ARGS.num_trials,
        max_trials_global=ARGS.num_trials,
        runner=get_runner(),
        work_dir=ARGS.work_dir,
    )
    executable = ms.relax_integration.compile_relax(
        db,
        mod=relax_mod,
        target=ARGS.target,
        params=params,
    )

    for input_name, input_shape in input_info.items():
        if input_dtype.startswith("float"):
            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
        else:
            input_data[input_name] = np.random.randint(
                low=0, high=10000, size=input_shape, dtype=input_dtype
            )

    # for documentation purposes
    start_time = datetime.datetime.now()

    if ARGS.rpc_config:
        result = run_module_via_rpc(
            rpc_config=ARGS.rpc_config,
            lib=executable.mod,
            dev_type=ARGS.target.kind.name,
            args=input_data,
            continuation=f_measurement,
        )
    else:
        dev = tvm.device(ARGS.target.kind.name)
        result = f_measurement(executable.mod, dev, input_data)

    print(result)

    if not ARGS.results_file:
        return

    out_path = os.path.abspath(os.path.expanduser(ARGS.results_file))
    with open(out_path, "w") as out_file:
        writer = csv.writer(out_file)
        # write experiment parameters at the top as a record
        writer.writerow(["start", str(start_time)])
        writer.writerow(["workload", ARGS.workload])
        writer.writerow(["input_shape", ARGS.input_shape])
        writer.writerow(["target", ARGS.target])
        writer.writerow(["num_measurement_repeats", ARGS.num_measurement_repeats])
        for res in result.results:
            writer.writerow([str(res)])


if __name__ == "__main__":
    main()
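The script's RPC handling in `_parse_args` follows an all-or-none flag pattern: either all of `--rpc-host`, `--rpc-port`, and `--rpc-key` are given (remote measurement through an RPC tracker) or none are (local measurement). A standalone sketch of that validation idiom, with illustrative names rather than the script's own API:

```python
import argparse


def parse_rpc_flags(argv):
    """All-or-none validation for a group of related CLI flags."""
    p = argparse.ArgumentParser()
    p.add_argument("--rpc-host", type=str, default=None)
    p.add_argument("--rpc-port", type=int, default=None)
    p.add_argument("--rpc-key", type=str, default=None)
    ns = p.parse_args(argv)
    flags = (ns.rpc_host, ns.rpc_port, ns.rpc_key)
    if all(f is not None for f in flags):
        mode = "rpc"       # enough information to reach a tracker
    elif all(f is None for f in flags):
        mode = "local"     # fall back to measuring on this machine
    else:
        # a partial set of flags is always a user error
        raise SystemExit("set all of --rpc-host, --rpc-port, --rpc-key, or none")
    return ns, mode


print(parse_rpc_flags([])[1])  # local
print(parse_rpc_flags(["--rpc-host", "h", "--rpc-port", "9190", "--rpc-key", "k"])[1])  # rpc
```

Failing loudly on a partial flag set, as the script's `assert` does, avoids silently measuring locally when the user intended a remote run.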
57 changes: 57 additions & 0 deletions apps/relax_examples/mlp.py
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example code on creating, compiling, and running an MLP model in relax


import tvm
from tvm import relax, tir, topi
import numpy as np


def build_mlp(data, weight):
    bb = relax.BlockBuilder()

    with bb.function("mlp", [data, weight]):
        gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
        gv1 = bb.emit_te(topi.nn.relu, gv0)
        bb.emit_func_output(gv1)

    mod = bb.get()
    return mod


if __name__ == "__main__":
    # symbolic dimensions
    n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
    # create data and weight variables
    data = relax.Var("data", relax.TensorStructInfo([n, m], "float32"))
    weight = relax.Var("weight", relax.TensorStructInfo([m, n], "float32"))

    # construct a mlp model
    mod = build_mlp(data, weight)

    # build and create vm executor
    target = tvm.target.Target("llvm", host="llvm")
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())

    # run the mlp model on relax vm
    data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
    weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
    res = vm["mlp"](data, weight)
    print(res)
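mlp.py declares its inputs with symbolic dimensions `n` and `m`, which are bound to concrete values (16 and 32) only when the VM is called. A dependency-free sketch of that bind-and-check step — a hypothetical helper for illustration, not part of the TVM API:

```python
def bind_symbolic_shape(sym_shape, concrete_shape, bindings=None):
    """Match a symbolic shape like ('n', 'm') against a concrete one like
    (16, 32), accumulating bindings and rejecting inconsistent reuse."""
    bindings = {} if bindings is None else bindings
    if len(sym_shape) != len(concrete_shape):
        raise ValueError("rank mismatch")
    for dim, value in zip(sym_shape, concrete_shape):
        if isinstance(dim, str):  # symbolic dimension: bind or check
            if bindings.setdefault(dim, value) != value:
                raise ValueError(f"inconsistent binding for {dim}")
        elif dim != value:  # static dimension: must match exactly
            raise ValueError("static dimension mismatch")
    return bindings


# data is (n, m) = (16, 32); weight is (m, n), so it must be (32, 16)
b = bind_symbolic_shape(("n", "m"), (16, 32))
b = bind_symbolic_shape(("m", "n"), (32, 16), b)
print(b)  # {'n': 16, 'm': 32}
```

This is why the example can pass (16, 32) and (32, 16) arrays to a function declared over `[n, m]` and `[m, n]`: the same variable bound in one input constrains the other.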