Skip to content

Commit

Permalink
[Pass] New Python ExprVisitor/ExprMutator! (tlc-pack#190)
Browse files Browse the repository at this point in the history
Add decorators `visitor` and `mutator` to help users create `ExprVisitor` and `ExprMutator` in Python. Users can customize visit/rewrite/post-order-rewrite function in Python.  `PyExprVisitor` and `PyExprMutator` lists the functions users can customize.
  • Loading branch information
LeshengJin authored and junrushao committed Feb 5, 2023
1 parent f4a55ea commit 8d5e8fe
Show file tree
Hide file tree
Showing 10 changed files with 2,606 additions and 535 deletions.
538 changes: 538 additions & 0 deletions include/tvm/relax/expr_functor.h

Large diffs are not rendered by default.

50 changes: 38 additions & 12 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,27 @@ def _extract(inst: type, name: str):
def method(*args, **kwargs):
return getattr(inst, name)(*args, **kwargs)

if getattr(base, name) is getattr(cls, name) and name != "__str__":
# for task scheduler return None means calling default function
# otherwise it will trigger a TVMError of method not implemented
# on the c++ side when you call the method, __str__ not required
return None
return method
for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]):
# extract functions that differ from the base class
if not hasattr(base_cls, name):
continue
if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__":
continue
return method

# for task scheduler return None means calling default function
# otherwise it will trigger a TVMError of method not implemented
# on the c++ side when you call the method, __str__ not required
return None

assert isinstance(cls.__base__, type)
if hasattr(cls, "_type") and cls._type == "TVMDerivedObject":
raise TypeError(
(
f"Inheritance from a decorated object `{cls.__name__}` is not allowed. "
f"Please inherit from `{cls.__name__}._cls`."
)
)
assert hasattr(
cls, "_tvm_metadata"
), "Please use the user-facing method overriding class, i.e., PyRunner."
Expand All @@ -95,6 +108,9 @@ def method(*args, **kwargs):
class TVMDerivedObject(metadata["cls"]): # type: ignore
"""The derived object to avoid cyclic dependency."""

_cls = cls
_type = "TVMDerivedObject"

def __init__(self, *args, **kwargs):
"""Constructor."""
self.handle = None
Expand All @@ -111,12 +127,22 @@ def __init__(self, *args, **kwargs):
# using weakref to avoid cyclic dependency
self._inst._outer = weakref.ref(self)

def __getattr__(self, name: str):
"""Bridge the attribute function."""
try:
return self._inst.__getattribute__(name)
except AttributeError:
return super(TVMDerivedObject, self).__getattr__(name)
def __getattr__(self, name):
# fall back to instance attribute if there is not any
# return self._inst.__getattribute__(name)
import inspect # pylint: disable=import-outside-toplevel

result = self._inst.__getattribute__(name)
if inspect.ismethod(result):

def method(*args, **kwargs):
return result(*args, **kwargs)

# set __own__ to aviod implicit deconstruction
setattr(method, "__own__", self)
return method

return result

def __setattr__(self, name, value):
if name not in ["_inst", "key", "handle"]:
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,5 @@

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor
ExprMutatorBase = expr_functor.ExprMutatorBase
ExprMutator = expr_functor.ExprMutator
PyExprVisitor = expr_functor.PyExprVisitor
PyExprMutator = expr_functor.PyExprMutator
Loading

0 comments on commit 8d5e8fe

Please sign in to comment.