From 304048c33956dddb5027fec26541d57f903d8ca2 Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Thu, 17 Nov 2022 17:02:11 -0800 Subject: [PATCH] Fix after rebase, and reorganize the TVMScript folder structure. Co-authored-by: Junru Shao Co-authored-by: Siyuan Feng --- Jenkinsfile | 9 +- include/tvm/relax/tuning_api.h | 7 + include/tvm/tir/analysis.h | 4 - python/tvm/script/__init__.py | 14 +- python/tvm/script/_parser/core/dispatch.py | 63 ---- python/tvm/script/_parser/core/entry.py | 46 --- python/tvm/script/_parser/core/evaluator.py | 284 ---------------- python/tvm/script/_parser/core/parser.py | 309 ------------------ python/tvm/script/_parser/ir/__init__.py | 21 -- python/tvm/script/_parser/ir/entry.py | 48 --- python/tvm/script/_parser/ir/parser.py | 68 ---- python/tvm/script/_parser/tir/__init__.py | 24 -- python/tvm/script/_parser/tir/entry.py | 87 ----- python/tvm/script/_parser/tir/operation.py | 84 ----- python/tvm/script/_parser/tir/parser.py | 268 --------------- python/tvm/script/parser/__init__.py | 3 +- python/tvm/script/parser/core/entry.py | 3 + python/tvm/script/parser/core/parser.py | 48 ++- python/tvm/script/parser/ir/parser.py | 40 ++- .../{_parser => parser}/relax/__init__.py | 0 .../script/{_parser => parser}/relax/entry.py | 3 +- .../{_parser => parser}/relax/parser.py | 0 python/tvm/script/parser/tir/operation.py | 12 +- python/tvm/script/parser_v1/parser.py | 1 - python/tvm/script/parser_v1/tir/intrin.py | 1 - python/tvm/te/operation.py | 7 +- src/relax/transform/fuse_tir.cc | 2 +- src/relax/transform/meta_schedule.cc | 2 +- src/relax/transform/tuning_api/database.cc | 3 +- src/relay/backend/utils.cc | 2 +- src/te/operation/create_primfunc.cc | 20 +- src/te/operation/create_primfunc.h | 6 +- .../test_hexagon/test_relax_integration.py | 4 +- tests/python/relax/test_tvmscript_parser.py | 8 +- 34 files changed, 136 insertions(+), 1365 deletions(-) delete mode 100644 python/tvm/script/_parser/core/dispatch.py delete mode 100644 python/tvm/script/_parser/core/entry.py delete mode 100644 python/tvm/script/_parser/core/evaluator.py delete mode 100644 python/tvm/script/_parser/core/parser.py delete mode 100644 python/tvm/script/_parser/ir/__init__.py delete mode 100644 python/tvm/script/_parser/ir/entry.py delete mode 100644 python/tvm/script/_parser/ir/parser.py delete mode 100644 python/tvm/script/_parser/tir/__init__.py delete mode 100644 python/tvm/script/_parser/tir/entry.py delete mode 100644 python/tvm/script/_parser/tir/operation.py delete mode 100644 python/tvm/script/_parser/tir/parser.py rename python/tvm/script/{_parser => parser}/relax/__init__.py (100%) rename python/tvm/script/{_parser => parser}/relax/entry.py (98%) rename python/tvm/script/{_parser => parser}/relax/parser.py (100%) diff --git a/Jenkinsfile b/Jenkinsfile index 9d6315f100..42648acd6c 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -50,14 +50,14 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_lint = 'tlcpack/ci-lint:20220925-060158-71f25b3d6' -ci_gpu = 'tlcpack/ci-gpu:20220925-060158-71f25b3d6' -ci_cpu = 'tlcpack/ci-cpu:20220925-060158-71f25b3d6' +ci_lint = 'tlcpack/ci-lint:20221025-182121-e41d0ed6e' +ci_gpu = 'tlcpack/ci-gpu:20221025-182121-e41d0ed6e' +ci_cpu = 'tlcpack/ci-cpu:20221025-182121-e41d0ed6e' ci_wasm = 'tlcpack/ci-wasm:v0.72' ci_i386 = 'tlcpack/ci-i386:v0.75' ci_qemu = 'tlcpack/ci-qemu:v0.11' ci_arm = 'tlcpack/ci-arm:v0.08' -ci_hexagon = 'tlcpack/ci-hexagon:20220925-060158-71f25b3d6' +ci_hexagon = 'tlcpack/ci-hexagon:20221025-182121-e41d0ed6e' // <--- End of regex-scanned config. // Parameters to allow overriding (in Jenkins UI), the images @@ -372,4 +372,3 @@ stage('Build and Test') { Utils.markStageSkippedForConditional('BUILD: CPU') } } - diff --git a/include/tvm/relax/tuning_api.h b/include/tvm/relax/tuning_api.h index d302b632b5..dc4f5715d2 100644 --- a/include/tvm/relax/tuning_api.h +++ b/include/tvm/relax/tuning_api.h @@ -297,6 +297,13 @@ class TuningRecord : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); }; +/*! \brief The equality check for Workload */ +struct WorkloadEqual { + bool operator()(const meta_schedule::Workload& a, const meta_schedule::Workload& b) const { + return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + } +}; + /* \brief The abstract interface of database. */ class DatabaseNode : public runtime::Object { public: diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index d62883056d..e9796eca65 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -260,10 +260,6 @@ const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var); * \return The anchor block if found, nullptr otherwise. */ const tir::BlockNode* FindAnchorBlock(const IRModule& mod); - * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps. - * \param func The PrimFunc to be analyzed. - * \return The Op Pattern Kind. - */ // Pass variants of verification analysis // directly throws RuntimeError when verification fails. diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 8f776bb951..3274859eb9 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script APIs of TVM Python Package, aimed to support TIR""" -from . import _parser, parser_v1 +"""TVM Script APIs of TVM Python Package""" +from . import parser, parser_v1 ############# -from ._parser import ir -from ._parser import ir_module -from ._parser import parse as from_source_v2 -from ._parser import tir -from ._parser import relax +from .parser import ir +from .parser import ir_module +from .parser import parse as from_source_v2 +from .parser import tir +from .parser import relax ############# from .parser_v1 import from_source as from_source_v1 diff --git a/python/tvm/script/_parser/core/dispatch.py b/python/tvm/script/_parser/core/dispatch.py deleted file mode 100644 index f10b90961a..0000000000 --- a/python/tvm/script/_parser/core/dispatch.py +++ /dev/null @@ -1,63 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type - -from .doc import AST - -if TYPE_CHECKING: - from .parser import Parser - - -ParseMethod = Callable[["Parser", AST], None] -ParseVTable: Dict[Tuple[str, str], ParseMethod] = {} - -OpMethod = Callable[..., Any] -OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {} - - -def register(token: str, type_name: str): - """Register a method for a dispatch token and type name""" - - def f(method: ParseMethod): - ParseVTable[(token, type_name)] = method - - return f - - -def get( - token: str, - type_name: str, - default: Optional[ParseMethod] = None, -) -> Optional[ParseMethod]: - return ParseVTable.get((token, type_name), default) - - -def register_op(ty: Type, op: AST, operand_index: int): # pylint: disable=invalid-name - def f(method: OpMethod): - OpVTable[(ty, op, operand_index)] = method - - return f - - -def get_op( # pylint: disable=invalid-name - ty: Type, - op: Type, - operand_index: int, - default: Optional[OpMethod] = None, -) -> Optional[OpMethod]: - return OpVTable.get((ty, op, operand_index), default) diff --git a/python/tvm/script/_parser/core/entry.py b/python/tvm/script/_parser/core/entry.py deleted file mode 100644 index afd3cb5027..0000000000 --- a/python/tvm/script/_parser/core/entry.py +++ /dev/null @@ -1,46 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -"""The entry point of TVM parser.""" -from typing import Any, Union - -from ...ir_builder import IRBuilder -from . import doc -from .diagnostics import Source -from .parser import Parser - - -def parse(program: Union[doc.AST, Any, str], extra_vars=None): - if extra_vars is None: - from tvm.script._parser import ir # pylint: disable=import-outside-toplevel - from tvm.script._parser import relax # pylint: disable=import-outside-toplevel - from tvm.script._parser import tir # pylint: disable=import-outside-toplevel - - extra_vars = { - "I": ir, - "ir": ir, - "T": tir, - "tir": tir, - "relax": relax, - "R": relax, - } - - source = Source(program) - parser = Parser(source) - with IRBuilder() as builder: - parser.parse(extra_vars=extra_vars) - return builder.get() diff --git a/python/tvm/script/_parser/core/evaluator.py b/python/tvm/script/_parser/core/evaluator.py deleted file mode 100644 index 0c2ccee48a..0000000000 --- a/python/tvm/script/_parser/core/evaluator.py +++ /dev/null @@ -1,284 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -"""AST Evaluation""" -import ast -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union - -from . import dispatch, doc - -if TYPE_CHECKING: - from .parser import Parser - -DEFAULT_OP: Dict[Type, Callable[..., Any]] = { - doc.Add: lambda a, b: a + b, - doc.Sub: lambda a, b: a - b, - doc.Mult: lambda a, b: a * b, - doc.Div: lambda a, b: a / b, - doc.FloorDiv: lambda a, b: a // b, - doc.Mod: lambda a, b: a % b, - doc.LShift: lambda a, b: a << b, - doc.RShift: lambda a, b: a >> b, - doc.BitOr: lambda a, b: a | b, - doc.BitXor: lambda a, b: a ^ b, - doc.BitAnd: lambda a, b: a & b, - doc.MatMult: lambda a, b: a @ b, - # fmt: off - doc.Pow: lambda a, b: a**b, - # fmt: on - doc.Eq: lambda a, b: a == b, - doc.NotEq: lambda a, b: a != b, - doc.Lt: lambda a, b: a < b, - doc.LtE: lambda a, b: a <= b, - doc.Gt: lambda a, b: a > b, - doc.GtE: lambda a, b: a >= b, - doc.Is: lambda a, b: a is b, - doc.IsNot: lambda a, b: a is not b, - doc.In: lambda a, b: a in b, - doc.NotIn: lambda a, b: a not in b, - doc.And: lambda a, b: a and b, - doc.Or: lambda a, b: a or b, - doc.Invert: lambda a: ~a, - doc.Not: lambda a: not a, - doc.UAdd: lambda a: +a, - doc.USub: lambda a: -a, -} - - -class ExprEvaluator: - - parser: "Parser" - value_table: Dict[str, Any] - new_value_count: int - - def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None: - super().__init__() - self.parser = parser - self.value_table = value_table - self.new_value_count = 0 - - @staticmethod - def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: - self = ExprEvaluator(parser, value_table) - result = self._visit(node) # pylint: disable=protected-access - if isinstance(result, doc.Name): - if result.id not in self.value_table: - self.parser.report_error(result, f"Undefined variable: {result.id}") - return self.value_table[result.id] - if isinstance(result, doc.Constant): - return result.value - raise TypeError(f"Unexpected result type: {type(result)}") - - def _add_intermediate_result(self, value: Any) -> doc.Name: - name = f"__tvm_tmp_value_{self.new_value_count}" - self.new_value_count += 1 - self.value_table[name] = value - lineno = 0 - col_offset = 0 - return doc.Name( - id=name, - ctx=doc.Load( - lineno=lineno, - col_offset=col_offset, - end_lineno=None, - end_col_offset=None, - ), - lineno=lineno, - col_offset=col_offset, - end_lineno=None, - end_col_offset=None, - ) - - def _visit(self, node: doc.AST) -> Any: - if isinstance(node, list): - return [self._visit(n) for n in node] - if isinstance(node, tuple): - return tuple(self._visit(n) for n in node) - assert isinstance(node, doc.AST) - if isinstance(node, doc.Name): - if node.id not in self.value_table: - self.parser.report_error(node, f"Undefined variable: {node.id}") - return node - if isinstance( - node, - ( - doc.Constant, - doc.expr_context, - doc.operator, - doc.boolop, - doc.unaryop, - doc.cmpop, - ), - ): - return node - if not isinstance(node, (doc.expr, doc.slice)): - return node - if isinstance(node, doc.Lambda): - return self._eval_lambda(node) - fields = {} - for field in node.__class__._FIELDS: # pylint: disable=protected-access - attr = getattr(node, field) - if isinstance(attr, (doc.AST, tuple, list)): - fields[field] = self._visit(attr) - else: - fields[field] = attr - try: - if isinstance(node, doc.BoolOp): - value = self._eval_bool_op(fields) - elif isinstance(node, doc.Compare): - value = self._eval_compare(fields) - elif isinstance(node, doc.UnaryOp): - value = self._eval_unary_op(fields) - elif isinstance(node, doc.BinOp): - value = self._eval_bin_op(fields) - elif isinstance(node, doc.Slice): - value = self._eval_slice(fields) - else: - value = self._eval_expr(node.__class__(**fields)) - except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) - return self._add_intermediate_result(value) - - def _eval_lambda(self, node: doc.Lambda) -> Any: - try: - value = self._eval_expr(node) - except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) - return self._add_intermediate_result(value) - - def _eval_bool_op(self, fields: Dict[str, Any]) -> Any: - op = fields["op"] - if not isinstance(op, (doc.And, doc.Or)): - raise TypeError(f"Unexpected operator: {op}") - value = self._eval_expr(fields["values"][0]) - for rhs in fields["values"][1:]: - value = _eval_op(op, values=[value, self._eval_expr(rhs)]) - return value - - def _eval_compare(self, fields: Dict[str, Any]) -> Any: - value = self._eval_expr(fields["left"]) - for op, rhs in zip(fields["ops"], fields["comparators"]): - value = _eval_op(op, values=[value, self._eval_expr(rhs)]) - return value - - def _eval_unary_op(self, fields: Dict[str, Any]) -> Any: - value = self._eval_expr(fields["operand"]) - value = _eval_op(fields["op"], values=[value]) - return value - - def _eval_bin_op(self, fields: Dict[str, Any]) -> Any: - return _eval_op( - fields["op"], - values=[ - self._eval_expr(fields["left"]), - self._eval_expr(fields["right"]), - ], - ) - - def _eval_slice(self, fields: Dict[str, Any]) -> Any: - lower, upper, step = fields["lower"], fields["upper"], fields["step"] - - lower = self._eval_expr(lower) if lower is not None else None - upper = self._eval_expr(upper) if upper is not None else None - step = self._eval_expr(step) if step is not None else None - - return slice(lower, upper, step) - - def _eval_expr(self, v: Any) -> Any: - return _eval_expr(v, self.value_table) - - -def eval_expr( - parser: "Parser", - node: Union[doc.expr, doc.Expression], - dict_globals: Optional[Dict[str, Any]], -) -> Any: - value_table = {} - if dict_globals is not None: - value_table.update(dict_globals) - return ExprEvaluator.eval(parser, value_table, node) - - -def eval_assign( - parser: "Parser", - target: doc.expr, - source: Any, -) -> Dict[str, Any]: - try: - return _eval_assign(target, source) - except Exception as e: # pylint: disable=broad-except,invalid-name - parser.report_error(target, f"Failed to evaluate assignment: {str(e)}") - raise - - -def _eval_expr( - node: Union[doc.expr, doc.Expression], - dict_globals: Optional[Dict[str, Any]], -) -> Any: - node = doc.from_doc(node) - if isinstance(node, ast.expr): - node = ast.Expression(body=node) - assert isinstance(node, ast.Expression), "Expects an ast.Expression, but gets: " + str(node) - if dict_globals is None: - dict_globals = {} - node = ast.fix_missing_locations(node) - exe = compile(node, filename="", mode="eval") - return eval(exe, dict_globals) # pylint: disable=eval-used - - -def _eval_op( - op: doc.AST, - values: List[Any], -): - op_type = type(op) # pylint: disable=protected-access - for i, v in enumerate(values): - v_type = getattr(type(v), "_dispatch_type", None) - if v_type is None: - continue - f = dispatch.get_op(ty=v_type, op=op_type, operand_index=i, default=None) - if f is not None: - return f(*values) - return DEFAULT_OP[op_type](*values) - - -def _eval_assign( - target: doc.expr, - source: Any, -) -> Dict[str, Any]: - target = doc.from_doc(target) - assert isinstance(target, ast.expr) - RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name - rhs_var_name = RHS_VAR_NAME - dict_locals = {rhs_var_name: source} - mod = ast.fix_missing_locations( - ast.Module( - body=[ - ast.Assign( - targets=[target], - value=ast.Name( - id=rhs_var_name, - ctx=ast.Load(), - ), - ) - ], - type_ignores=[], - ) - ) - exe = compile(mod, filename="", mode="exec") - exec(exe, {}, dict_locals) # pylint: disable=exec-used - del dict_locals[rhs_var_name] - return dict_locals diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py deleted file mode 100644 index 8f6850b454..0000000000 --- a/python/tvm/script/_parser/core/parser.py +++ /dev/null @@ -1,309 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -"""The core parser""" -from collections import defaultdict -from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Optional, Set, Union - -from tvm._ffi.base import TVMError -from tvm.error import DiagnosticError - -from . import dispatch, doc -from .diagnostics import Diagnostics, Source -from .evaluator import eval_assign, eval_expr - -DEFAULT_VISIT = { - "Interactive", - "Module", - "Expression", - "Pass", -} - - -def _deferred(f: Callable[[], None]): - @contextmanager - def context(): - try: - yield - finally: - f() - - return context() - - -def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument - pass - - -class VarTableFrame: - vars: Set[str] - - def __init__(self): - self.vars = set() - - def add(self, var: str): - if var in self.vars: - raise ValueError(f"Variable {var} already defined in current scope") - self.vars.add(var) - - def pop_all(self, fn_pop: Callable[[str], None]): - for var in self.vars: - fn_pop(var) - self.vars.clear() - - -class VarTable: - - frames: List[VarTableFrame] - name2value: Dict[str, List[Any]] - - def __init__(self): - self.frames = [] - self.name2value = defaultdict(list) - - def with_frame(self): - def pop_frame(): - frame = self.frames.pop() - frame.pop_all(lambda name: self.name2value[name].pop()) - - self.frames.append(VarTableFrame()) - return _deferred(pop_frame) - - def add(self, var: str, value: Any, allow_shadowing: bool = False): - # Skip if the key and value are equal to those in the var_table - if self.name2value[var] and self.name2value[var][-1] == value: - return - if allow_shadowing and var in self.frames[-1].vars: - # Shadowing - self.name2value[var][-1] = value - else: - self.frames[-1].add(var) - self.name2value[var].append(value) - - def get(self) -> Dict[str, Any]: - return {key: values[-1] for key, values in self.name2value.items() if values} - - def exist(self, value: Any): - for v in self.name2value.values(): - if v is value: - return True - return False - - -def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: - def _wrapper(self: "Parser", node: doc.AST) -> None: - try: - return func(self, node) - except DiagnosticError: - raise - except Exception as e: # pylint: disable=broad-except,invalid-name - self.report_error(node, e) - raise - - return _wrapper - - -def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: - for token in [self.dispatch_tokens[-1], "default"]: - func = dispatch.get(token=token, type_name=type_name, default=None) - if func is not None: - return _dispatch_wrapper(func) - return _dispatch_wrapper(lambda self, node: self.generic_visit(node)) - - -class Parser(doc.NodeVisitor): - """The TVMScript parser""" - - diag: Diagnostics - dispatch_tokens: List[str] - var_table: VarTable - - def __init__(self, source: Source) -> None: - self.diag = Diagnostics(source) - self.dispatch_tokens = ["default"] - self.var_table = VarTable() - - def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: - if extra_vars is None: - extra_vars = {} - with self.var_table.with_frame(): - for k, v in extra_vars.items(): - self.var_table.add(k, v) - node = self.diag.source.as_ast() - self.visit(node) - - def get_dispatch_token(self, node: doc.FunctionDef) -> str: - if not isinstance(node, doc.FunctionDef): - self.report_error(node, "Only can get dispatch token for function.") - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - return decorator.dispatch_token - - def with_dispatch_token(self, token: str): - def pop_token(): - self.dispatch_tokens.pop() - - self.dispatch_tokens.append(token) - return _deferred(pop_token) - - def eval_expr( - self, - node: Union[doc.Expression, doc.expr], - extra_vars: Optional[Dict[str, Any]] = None, - ) -> Any: - var_values = self.var_table.get() - if extra_vars is not None: - for k, v in extra_vars.items(): - var_values[k] = v - return eval_expr(self, node, var_values) - - def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: - if isinstance(target, (doc.Tuple, doc.List)): - vars: Set[str] = set() # pylint: disable=redefined-builtin - for i in target.elts: - res = self._duplicate_lhs_check(i) - if isinstance(res, bool) and res: - return True - assert isinstance(res, set) - if vars & res: - return True - vars = vars.union(res) - return vars - elif isinstance(target, doc.Name): - return {target.id} - else: - self.report_error(target, "Invalid type in assign statement") - raise NotImplementedError - - def eval_assign( - self, - target: doc.expr, - source: Any, - bind_value: Callable[["Parser", doc.expr, str, Any], Any], - allow_shadowing: bool = False, - ) -> Dict[str, Any]: - if self._duplicate_lhs_check(target) is True: - self.report_error(target, "Duplicate vars assigned.") - var_values = eval_assign(self, target, source) - for k, v in var_values.items(): - var = bind_value(self, target, k, v) - self.var_table.add(k, var, allow_shadowing) - return var_values - - def report_error( - self, node: doc.AST, err: Union[Exception, str] - ) -> None: # pylint: disable=no-self-use - # Only take the last line of the error message - if isinstance(err, (TVMError, ValueError, TypeError)): - msg = list(filter(None, str(err).split("\n")))[-1] - else: - msg = str(err) - self.diag.error(node, msg) - - def visit(self, node: doc.AST) -> None: - if isinstance(node, (list, tuple)): - for item in node: - self.visit(item) - return - if not isinstance(node, doc.AST): - return - name = node.__class__.__name__.split(".")[-1] - if name in DEFAULT_VISIT: - func = self.generic_visit - else: - func = getattr(self, "visit_" + name, None) - if func is None: - raise NotImplementedError(f"Visitor of AST node is not implemented: {name}") - try: - func(node) - except DiagnosticError: - raise - except Exception as e: # pylint: disable=broad-except,invalid-name - self.report_error(node, str(e)) - raise - - def visit_body(self, node: List[doc.stmt]) -> Any: - for stmt in node: - self.visit(stmt) - - def visit_tvm_annotation(self, node: doc.expr) -> Any: - return _dispatch(self, "tvm_annotation")(self, node) - - def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name - token = self.get_dispatch_token(node) - current_token = self.dispatch_tokens[-1] - func = dispatch.get(token=token, type_name="FunctionDef", default=None) - if func is None: - self.report_error(node, "The parser does not understand the decorator") - pre_func = dispatch.get( - token=current_token, type_name="pre_token_switch", default=_do_nothing - ) - post_func = dispatch.get( - token=current_token, type_name="post_token_switch", default=_do_nothing - ) - pre_func(self, node) - _dispatch_wrapper(func)(self, node) - post_func(self, node) - - def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: - token = self.get_dispatch_token(node) - with self.with_dispatch_token(token): - _dispatch(self, "tvm_declare_function")(self, node) - - def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name - func = dispatch.get(token="ir", type_name="ClassDef", default=None) - if func is None: - self.report_error(node, "The parser does not understand the decorator") - _dispatch_wrapper(func)(self, node) - - def visit_arguments(self, node: doc.arguments) -> Any: - return _dispatch(self, "arguments")(self, node) - - def visit_For(self, node: doc.For) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "For")(self, node) - - def visit_While(self, node: doc.While) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "While")(self, node) - - def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "With")(self, node) - - def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "Assign")(self, node) - - def visit_Expr(self, node: doc.Expr) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "Expr")(self, node) - - def visit_If(self, node: doc.If) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "If")(self, node) - - def visit_AnnAssign(self, node: doc.AnnAssign) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "AnnAssign")(self, node) - - def visit_AugAssign(self, node: doc.AugAssign) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "AugAssign")(self, node) - - def visit_Assert(self, node: doc.Assert) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "Assert")(self, node) - - def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name - return _dispatch(self, "Return")(self, node) diff --git a/python/tvm/script/_parser/ir/__init__.py b/python/tvm/script/_parser/ir/__init__.py deleted file mode 100644 index 8cf9b50665..0000000000 --- a/python/tvm/script/_parser/ir/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -from . import parser as _parser -from .entry import ir_module, is_defined_in_class - -__all__ = ["ir_module", "is_defined_in_class"] diff --git a/python/tvm/script/_parser/ir/entry.py b/python/tvm/script/_parser/ir/entry.py deleted file mode 100644 index e0a0213cd1..0000000000 --- a/python/tvm/script/_parser/ir/entry.py +++ /dev/null @@ -1,48 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -import inspect -from typing import Type - -from tvm.ir import IRModule - -from .._core import parse, utils - - -def is_defined_in_class(frames): - if len(frames) > 2: - maybe_class_frame = frames[2] - statement_list = maybe_class_frame[4] - if statement_list is None: - return False - first_statement = statement_list[0] - line = first_statement.strip() - if line.startswith("class "): - return True - if line.startswith("@") and "ir_module" in line: - return True - return False - - -def ir_module(f: Type) -> IRModule: - if not inspect.isclass(f): - raise TypeError(f"Expect a class, but got: {f}") - - return parse(f, utils.inspect_class_capture(f)) - - -setattr(ir_module, "dispatch_token", "ir") diff --git a/python/tvm/script/_parser/ir/parser.py b/python/tvm/script/_parser/ir/parser.py deleted file mode 100644 index b6a8cab060..0000000000 --- a/python/tvm/script/_parser/ir/parser.py +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -from typing import Optional, Tuple - -from tvm.ir import PrimExpr, PrimType, RelayExpr, Type - -from ...ir_builder import ir as I -from .._core import Parser, dispatch, doc - - -def eval_func_type_shape( - self: Parser, node: doc.FunctionDef -) -> Tuple[Optional[Type], Optional[RelayExpr]]: - token = self.get_dispatch_token(node) - with self.with_dispatch_token(token): - result = self.visit_tvm_annotation(node.returns) - if result is None: - return None, None - elif isinstance(result, tuple) and len(result) == 2: - # relax dialect - return result - elif isinstance(result, PrimExpr): - # tir dialect - return PrimType(result.dtype), None - else: - raise TypeError(f"Unsupported annotation type: {result}") - - -@dispatch.register(token="ir", type_name="ClassDef") -def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: - with self.var_table.with_frame(): - with I.ir_module(): - for stmt in node.body: - if isinstance(stmt, doc.FunctionDef): - self.visit_tvm_declare_function(stmt) - with self.with_dispatch_token("ir"): - self.visit_body(node.body) - - -@dispatch.register(token="ir", type_name="Assign") -def _visit_assign(_self: Parser, _node: doc.Assign) -> None: - pass - - -@dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: - pass - - -@dispatch.register(token="default", type_name="tvm_declare_function") -def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: - global_var = I.decl_function(node.name) - self.var_table.add(node.name, global_var) diff --git a/python/tvm/script/_parser/tir/__init__.py b/python/tvm/script/_parser/tir/__init__.py deleted file mode 100644 index 930764f73d..0000000000 --- a/python/tvm/script/_parser/tir/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -from ...ir_builder.tir import * # pylint: disable=redefined-builtin -from ...ir_builder.tir import ir as _tir -from . import operation as _operation -from . import parser as _parser -from .entry import Buffer, Ptr, prim_func - -__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/_parser/tir/entry.py deleted file mode 100644 index 07bd75f351..0000000000 --- a/python/tvm/script/_parser/tir/entry.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -import inspect -from typing import Callable, Union - -from tvm.tir import Buffer, PrimFunc - -from ...ir_builder.tir import buffer_decl, ptr -from .._core import parse, utils -from ..ir import is_defined_in_class - - -def prim_func(f: Callable) -> Union[PrimFunc, Callable]: - if not inspect.isfunction(f): - raise TypeError(f"Expect a function, but got: {f}") - if is_defined_in_class(inspect.stack()): - return f - return parse(f, utils.inspect_function_capture(f)) - - -setattr(prim_func, "dispatch_token", "tir") - - -class BufferProxy: - def __call__( - self, - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=0, - offset_factor=0, - buffer_type="", - axis_separators=None, - ) -> Buffer: - return buffer_decl( - shape, - dtype=dtype, - data=data, - strides=strides, - elem_offset=elem_offset, - scope=scope, - align=align, - offset_factor=offset_factor, - buffer_type=buffer_type, - axis_separators=axis_separators, - ) - - def __getitem__(self, keys) -> Buffer: - if not isinstance(keys, tuple): - return self(keys) - if len(keys) >= 2 and not isinstance(keys[1], str): - return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore - - -class PtrProxy: - def __call__(self, dtype, storage_scope="global"): - if callable(dtype): - dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore - - def __getitem__(self, keys): - if not isinstance(keys, tuple): - return self(keys) - return self(*keys) - - -Buffer = BufferProxy() # pylint: disable=invalid-name -Ptr = PtrProxy() # pylint: disable=invalid-name diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/_parser/tir/operation.py deleted file mode 100644 index 87fb9406ae..0000000000 --- a/python/tvm/script/_parser/tir/operation.py +++ /dev/null @@ -1,84 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -from typing import Type - -from tvm import tir -from tvm.tir import IntImm - -from .._core import OpMethod, doc, register_op - - -def _register_expr_op(ty: Type): # pylint: disable=invalid-name - ty._dispatch_type = ty # pylint: disable=protected-access - - def _and(a, b): - if isinstance(a, bool): - a = IntImm("bool", a) - if isinstance(b, bool): - b = IntImm("bool", b) - return tir.And(a, b) - - def _or(a, b): - if isinstance(a, bool): - a = IntImm("bool", a) - if isinstance(b, bool): - b = IntImm("bool", b) - return tir.Or(a, b) - - def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name - register_op(ty, op, i)(m) - - for i in [0, 1]: - # Case 1. binop - r(doc.Add, i, lambda a, b: a + b) - r(doc.Sub, i, lambda a, b: a - b) - r(doc.Mult, i, lambda a, b: a * b) - r(doc.Div, i, lambda a, b: a / b) - r(doc.FloorDiv, i, lambda a, b: a // b) - r(doc.Mod, i, lambda a, b: a % b) - r(doc.LShift, i, lambda a, b: a << b) - r(doc.RShift, i, lambda a, b: a >> b) - r(doc.BitOr, i, lambda a, b: a | b) - r(doc.BitXor, i, lambda a, b: a ^ b) - r(doc.BitAnd, i, lambda a, b: a & b) - # doc.MatMult <-- not implemented - # doc.Pow <-- not implemented - # Case 2. cmpop - r(doc.Eq, i, tir.EQ) - r(doc.NotEq, i, tir.NE) - r(doc.Lt, i, tir.LT) - r(doc.LtE, i, tir.LE) - r(doc.Gt, i, tir.GT) - r(doc.GtE, i, tir.GE) - # doc.Is <-- not implemented - # doc.IsNot <-- not implemented - # doc.In <-- not implemented - # doc.NotIn <-- not implemented - # Case 3. boolop - r(doc.And, i, _and) - r(doc.Or, i, _or) - for i in [0]: - # Case 4. unaryop - r(doc.Invert, i, lambda a: ~a) - r(doc.Not, i, tir.Not) - r(doc.UAdd, i, lambda a: +a) - r(doc.USub, i, lambda a: -a) - - -_register_expr_op(tir.PrimExpr) -_register_expr_op(tir.IterVar) diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/_parser/tir/parser.py deleted file mode 100644 index 032555187f..0000000000 --- a/python/tvm/script/_parser/tir/parser.py +++ /dev/null @@ -1,268 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -import contextlib -from functools import partial -from typing import Any - -from tvm.ir import PrimType -from tvm.tir import Buffer, IterVar, PrimExpr, Var - -from ...ir_builder import tir as T -from ...ir_builder.base import IRBuilder -from ...ir_builder.base import IRBuilderFrame as Frame -from .._core import Parser, dispatch, doc - - -def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: - if isinstance(value, (list, tuple)): - for i, v in enumerate(value): - bind_with_value(self, node, f"{var_name}_{i}", v) - return value - elif isinstance(value, (Buffer, Var)): - IRBuilder.name(var_name, value) - return value - else: - self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement") - raise NotImplementedError - - -def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: - if isinstance(value, (list, tuple)): - for i, v in enumerate(value): - bind_with_value(self, node, f"{var_name}_{i}", v) - return value - elif isinstance(value, Var): - IRBuilder.name(var_name, value) - return value - else: - self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement") - raise NotImplementedError - - -def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any: - if isinstance(value, T.inline): - return value.value - elif isinstance(value, (list, tuple)): - for i, v in enumerate(value): - bind_with_value(self, _node, f"{var_name}_{i}", v) - return value - elif isinstance(value, Frame): - value.add_callback(partial(value.__exit__, None, None, None)) - res = value.__enter__() - IRBuilder.name(var_name, res) - return res - elif isinstance(value, (Buffer, IterVar)) or ( - isinstance(value, Var) and not self.var_table.exist(value) - ): - IRBuilder.name(var_name, value) - return value - elif isinstance(value, PrimExpr): - var = T.var(value.dtype) - IRBuilder.name(var_name, var) - frame = T.let(var, value) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() - return var - return value - - -@dispatch.register(token="tir", type_name="For") -def visit_for(self: Parser, node: doc.For) -> None: - for_frame = self.eval_expr(node.iter) - if not isinstance(for_frame, T.frame.ForFrame): - self.report_error( - node.iter, - "Expect the for loop to be one of the following: " - "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", - ) - with self.var_table.with_frame(): - with for_frame as iters: - self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value) - self.visit_body(node.body) - - -@dispatch.register(token="tir", type_name="While") -def visit_while(self: Parser, node: doc.While) -> None: - with self.var_table.with_frame(): - cond = self.eval_expr(node.test) - with T.While(cond): - self.visit_body(node.body) - - -@dispatch.register(token="tir", type_name="Assign") -def visit_assign(self: Parser, node: doc.Assign) -> None: - if len(node.targets) != 1: - self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") - lhs = node.targets[0] - rhs = self.eval_expr(node.value) - if isinstance(lhs, doc.Subscript): - if isinstance(lhs.slice, doc.Tuple): - indices = [] - for index in lhs.slice.elts: - indices.append(self.eval_expr(index)) - else: - indices = [self.eval_expr(lhs.slice)] - T.buffer_store(self.eval_expr(lhs.value), rhs, indices) - else: - self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) - - -@dispatch.register(token="tir", type_name="AugAssign") -def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: - lhs_pos = ( - node.target.lineno, - node.target.col_offset, - node.target.end_lineno, - node.target.end_col_offset, - ) - rhs_pos = ( - node.value.lineno, - node.value.col_offset, - node.value.end_lineno, - node.value.end_col_offset, - ) - node.target.ctx = doc.Load(*lhs_pos) - with self.var_table.with_frame(): - lhs_name = "__tvm_tmp_value_aug_assign_lhs" - rhs_name = "__tvm_tmp_value_aug_assign_rhs" - lhs_expr = self.eval_expr(node.target) - rhs_expr = self.eval_expr(node.value) - self.var_table.add(lhs_name, lhs_expr) - self.var_table.add(rhs_name, rhs_expr) - op = doc.BinOp( - doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos), - node.op, - doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos), - *lhs_pos, - ) - rhs = self.eval_expr(op) - lhs = node.target - lhs.ctx = doc.Store(*lhs_pos) - if isinstance(lhs, doc.Subscript): - if isinstance(lhs.slice, doc.Tuple): - indices = [] - for index in lhs.slice.elts: - indices.append(self.eval_expr(index)) - else: - indices = [self.eval_expr(lhs.slice)] - T.buffer_store(self.eval_expr(lhs.value), rhs, indices) - else: - self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) - - -@dispatch.register(token="tir", type_name="AnnAssign") -def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: - lhs = node.target - rhs = self.eval_expr(node.value) - ann_var = self.visit_tvm_annotation(node.annotation) - if not isinstance(ann_var, Var): - self.report_error(node.annotation, "Annotation should be Var") - self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - frame = T.let(ann_var, rhs) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() - - -@dispatch.register(token="tir", type_name="With") -def visit_with(self: Parser, node: doc.With) -> None: - with contextlib.ExitStack() as stack: - stack.enter_context(self.var_table.with_frame()) - for item in node.items: - frame = self.eval_expr(item.context_expr) - if not isinstance(frame, Frame): - self.report_error( - item.context_expr, "Invalid context expression in the with-statement." - ) - rhs = stack.enter_context(frame) - if item.optional_vars is not None: - self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) - self.visit_body(node.body) - - -@dispatch.register(token="tir", type_name="FunctionDef") -def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: - with self.var_table.with_frame(): - self.var_table.add("range", T.serial) - with T.prim_func(): - T.func_name(node.name) - if node.returns is not None: - ret_type = self.eval_expr(node.returns) - if callable(ret_type): - ret_type = PrimType(ret_type().dtype) - T.func_ret(ret_type) - with self.with_dispatch_token("tir"): - self.visit(node.args) - self.visit_body(node.body) - - -@dispatch.register(token="tir", type_name="arguments") -def visit_arguments(self: Parser, node: doc.arguments) -> None: - # TODO: handle different types of arguments: - # - vararg: arg | None - # - kwonlyargs: list[arg] - # - kw_defaults: list[expr | None] - # - kwarg: arg | None - # - defaults: list[expr] - # - posonlyargs: list[arg] - arg: doc.arg - for arg in node.args: - if arg.annotation is None: - self.report_error(arg, "Type annotation is required for function parameters.") - param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation)) - self.var_table.add(arg.arg, param) - - -@dispatch.register(token="tir", type_name="tvm_annotation") -def visit_tvm_annotation(self: Parser, node: doc.expr): - annotation = self.eval_expr(node) - if callable(annotation): - annotation = annotation() - return annotation - - -@dispatch.register(token="tir", type_name="Expr") -def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: - res = self.eval_expr(node.value) - if isinstance(res, Frame): - res.add_callback(partial(res.__exit__, None, None, None)) - res.__enter__() - - -@dispatch.register(token="tir", type_name="If") -def visit_if(self: Parser, node: doc.If) -> None: - with self.var_table.with_frame(): - with T.If(self.eval_expr(node.test)): - with T.Then(): - self.visit_body(node.body) - if node.orelse: - with T.Else(): - self.visit_body(node.orelse) - - -@dispatch.register(token="tir", type_name="Assert") -def visit_assert(self: Parser, node: doc.Assert) -> None: - cond = self.eval_expr(node.test) - msg = self.eval_expr(node.msg) - frame = T.Assert(cond, msg) - frame.add_callback(partial(frame.__exit__, None, None, None)) - frame.__enter__() - - -@dispatch.register(token="tir", type_name="Return") -def visit_return(self: Parser, node: doc.Return) -> None: - self.report_error(node, "Return is not allowed.") diff --git a/python/tvm/script/parser/__init__.py b/python/tvm/script/parser/__init__.py index 5161a2601c..678297799e 100644 --- a/python/tvm/script/parser/__init__.py +++ b/python/tvm/script/parser/__init__.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the Licens. """The parser""" -from . import _core, ir, tir +from . import _core, ir, tir, relax from ._core import parse from .ir import ir_module from .tir import prim_func +from .relax import function diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index bf6a118672..bb95d41171 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -42,6 +42,7 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) """ if extra_vars is None: from tvm.script.parser import ir # pylint: disable=import-outside-toplevel + from tvm.script.parser import relax # pylint: disable=import-outside-toplevel from tvm.script.parser import tir # pylint: disable=import-outside-toplevel extra_vars = { @@ -49,6 +50,8 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "ir": ir, "T": tir, "tir": tir, + "relax": relax, + "R": relax, } source = Source(program) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index c6d43f11cb..7ad7ad6b70 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -59,6 +59,10 @@ def context(): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. @@ -255,6 +259,17 @@ def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -453,30 +468,33 @@ def visit_tvm_annotation(self, node: doc.expr) -> Any: """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. """ - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 9532e7e32c..d719ac90b6 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -15,25 +15,55 @@ # specific language governing permissions and limitations # under the License. """The base parser for ir module""" +from typing import Optional, Tuple +from tvm.ir import PrimExpr, PrimType, RelayExpr, Type from ...ir_builder import ir as I from .._core import Parser, dispatch, doc +def eval_func_type_shape( + self: Parser, node: doc.FunctionDef +) -> Tuple[Optional[Type], Optional[RelayExpr]]: + """evaluate function type and shape. + Parameters + ---------- + self : Parser + The visiting parser. + node : doc.FunctionDef + The doc FunctionDef node. + """ + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + result = self.visit_tvm_annotation(node.returns) + if result is None: + return None, None + elif isinstance(result, tuple) and len(result) == 2: + # relax dialect + return result + elif isinstance(result, PrimExpr): + # tir dialect + return PrimType(result.dtype), None + else: + raise TypeError(f"Unsupported annotation type: {result}") + + @dispatch.register(token="ir", type_name="ClassDef") def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: """The class definition visiting method for ir module. - Parameters ---------- self : Parser The visiting parser. - node : doc.ClassDef The doc AST class definition node. """ + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) @@ -64,3 +94,9 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + + +@dispatch.register(token="default", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + global_var = I.decl_function(node.name) + self.var_table.add(node.name, global_var) diff --git a/python/tvm/script/_parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py similarity index 100% rename from python/tvm/script/_parser/relax/__init__.py rename to python/tvm/script/parser/relax/__init__.py diff --git a/python/tvm/script/_parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py similarity index 98% rename from python/tvm/script/_parser/relax/entry.py rename to python/tvm/script/parser/relax/entry.py index c105884510..c0bb48a843 100644 --- a/python/tvm/script/_parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -29,7 +29,6 @@ from ...ir_builder.relax import TensorType, tensor from .._core import parse, utils -from ..ir import is_defined_in_class FType = _TypeVar("FType", bound=_Callable) @@ -37,7 +36,7 @@ def function(f: FType) -> Union[Function, FType]: if not inspect.isfunction(f): raise TypeError(f"Expect a function, but got: {f}") - if is_defined_in_class(inspect.stack()): + if utils.is_defined_in_class(inspect.stack(), f): return f return parse(f, utils.inspect_function_capture(f)) diff --git a/python/tvm/script/_parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py similarity index 100% rename from python/tvm/script/_parser/relax/parser.py rename to python/tvm/script/parser/relax/parser.py diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index f0c04f47cd..ed8f07a063 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -46,12 +46,12 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0, 1]: # Case 1. binop - r(doc.Add, i, tir.Add) - r(doc.Sub, i, tir.Sub) - r(doc.Mult, i, tir.Mul) - r(doc.Div, i, tir.Div) - r(doc.FloorDiv, i, tir.FloorDiv) - r(doc.Mod, i, tir.FloorMod) + r(doc.Add, i, lambda a, b: a + b) + r(doc.Sub, i, lambda a, b: a - b) + r(doc.Mult, i, lambda a, b: a * b) + r(doc.Div, i, lambda a, b: a / b) + r(doc.FloorDiv, i, lambda a, b: a // b) + r(doc.Mod, i, lambda a, b: a % b) r(doc.LShift, i, lambda a, b: a << b) r(doc.RShift, i, lambda a, b: a >> b) r(doc.BitOr, i, lambda a, b: a | b) diff --git a/python/tvm/script/parser_v1/parser.py b/python/tvm/script/parser_v1/parser.py index 613a732139..1f8f71c271 100644 --- a/python/tvm/script/parser_v1/parser.py +++ b/python/tvm/script/parser_v1/parser.py @@ -38,7 +38,6 @@ from tvm.tir import buffer from tvm.tir.function import PrimFunc -from .. import _ffi_api from . import tir from .context_maintainer import ContextMaintainer from .diagnostics import TVMDiagnosticCtx diff --git a/python/tvm/script/parser_v1/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py index cfea0f8591..b998db3a8c 100644 --- a/python/tvm/script/parser_v1/tir/intrin.py +++ b/python/tvm/script/parser_v1/tir/intrin.py @@ -20,7 +20,6 @@ from typing import Any, List import tvm.tir -from tvm.tir import FloatImm from ....target import codegen from ..registry import register diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 2e66c9e0dc..4b3442889c 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -568,7 +568,9 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): def create_prim_func( - ops: List[_tensor.Tensor], tir_var_list: List[tvm.tir.Var] = None + ops: List[_tensor.Tensor], + tir_var_list: List[tvm.tir.Var] = None, + index_dtype_override: Optional[str] = None, ) -> tvm.tir.PrimFunc: """Create a TensorIR PrimFunc from tensor expression @@ -576,6 +578,7 @@ def create_prim_func( ---------- ops : List[Tensor] The source expression. + tir_var_list: List[Var] TIR variables to add as parameters to generated PrimFunc @@ -623,7 +626,7 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: """ if not isinstance(ops, (list, tuple, Array)): ops = [ops] - return _ffi_api.CreatePrimFunc(ops, tir_var_list) + return _ffi_api.CreatePrimFunc(ops, tir_var_list, index_dtype_override) def create_prim_func_from_outputs( diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index ad9274f874..4637ad1986 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -536,7 +536,7 @@ class FusedTIRConstructor : public ExprVisitor { body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers); body = tir::BlockRealize({}, Bool(true), Downcast(body)); tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, - Optional>(), DictAttrs(attr_map)); + DictAttrs(attr_map)); return func; } diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 7af7b678be..38bb0b0fd1 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -94,7 +94,7 @@ Pass MetaScheduleApplyDatabase(Optional work_dir) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext ctx) { - Database database; + Database database{nullptr}; if (Database::Current().defined()) { database = Database::Current().value(); } else { diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index 177f890d56..0d239e5fbf 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -116,8 +116,7 @@ class JSONDatabaseNode : public DatabaseNode { /*! \brief The path to the measurement table */ String path_measurement_record; /*! \brief All the workloads in the database */ - std::unordered_map + std::unordered_map workloads2idx_; /*! \brief All the tuning records in the database */ std::unordered_map> diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 8562cd15f7..f0e1f95de1 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -416,7 +416,7 @@ Optional DefaultTIRConverterImpl(const Array& args, return NullOpt; } } - PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64)); + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, {}, DataType::Int(64)); bool dynamic_loop_extent = false; tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 16bbbc3313..cf80925688 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -510,7 +510,8 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, - const Optional>& tir_var_list) { + const Optional>& tir_var_list, + std::optional index_dtype_override) { // Infomations used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. @@ -538,11 +539,22 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, } PrimFunc CreatePrimFunc(const Array& arg_list, - const Optional> tir_var_list) { - return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list); + const Optional> tir_var_list, + std::optional index_dtype_override) { + return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, index_dtype_override); } -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); +// TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); +TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) { + Array arg_list = args[0]; + Optional> tir_var_list = args[1]; + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[2].type_code() != kTVMNullptr) { + index_dtype_override = args[2].operator DataType(); + } + *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override); +}); } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 483c59f324..e032336758 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -31,7 +31,8 @@ namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list, - const Optional> tir_var_list); + const Optional> tir_var_list, + std::optional index_dtype_override); /*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the * constants array is N, the last N tensors in arg_list will be treated as constant tensors. @@ -40,7 +41,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list, */ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, - const Optional>& tir_var_list); + const Optional>& tir_var_list, + std::optional index_dtype_override = std::nullopt); } // namespace tir } // namespace tvm diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py b/tests/python/contrib/test_hexagon/test_relax_integration.py index 1947cc7eae..eb356f16e8 100644 --- a/tests/python/contrib/test_hexagon/test_relax_integration.py +++ b/tests/python/contrib/test_hexagon/test_relax_integration.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=missing-function-docstring,invalid-name,line-too-long +"""Relax hexagon test.""" import numpy as np import pytest @@ -21,7 +23,6 @@ from tvm import relay, relax, runtime from tvm.relax.testing import relay_translator from tvm.contrib.hexagon.session import Session -from tvm.script import relax as R, tir as T from tvm.relay import testing @@ -42,7 +43,6 @@ def test_conv2d(hexagon_session: Session): f = relay.Function([data, weight], y) relay_mod = tvm.IRModule.from_expr(f) - # target_hexagon = "llvm -keys=hexagon -link-params=0 -mattr=+hvxv69,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp -mcpu=hexagonv69 -mtriple=hexagon" target_hexagon = tvm.target.hexagon("v68") target = tvm.target.Target(target_hexagon, host=target_hexagon) relax_mod = relay_translator.from_relay(relay_mod["main"], target) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 7c8afec7ef..a6c94096f7 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -21,9 +21,9 @@ import tvm import tvm.testing from tvm import IRModule, relax, tir -from tvm.script._parser import ir as I -from tvm.script._parser import relax as R -from tvm.script._parser import tir as T +from tvm.script.parser import ir as I +from tvm.script.parser import relax as R +from tvm.script.parser import tir as T from tvm.relax import RuntimeDepShape, DynTensorType @@ -610,7 +610,7 @@ def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"): # Current error: `gv2.shape` is different: (10, 5) vs RuntimeDepShape() # tvm.ir.assert_structural_equal(Mod0, Mod1) - with pytest.raises(tvm.error.DiagnosticError): + with pytest.raises(OSError): @I.ir_module class ErrorMod: