Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[CI] Enable Mypy type checking for Relax; Fix typing errors to pass M…
Browse files Browse the repository at this point in the history
…ypy checking. (#270)
  • Loading branch information
YuchenJin authored Nov 5, 2022
1 parent 1d6d8e7 commit a443b8d
Show file tree
Hide file tree
Showing 22 changed files with 320 additions and 284 deletions.
26 changes: 13 additions & 13 deletions python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def post_order_visit(expr, fvisit):
fvisit : function
The visitor function to be applied.
"""
return _ffi_api.post_order_visit(expr, fvisit)
return _ffi_api.post_order_visit(expr, fvisit) # type: ignore


def well_formed(mod: tvm.IRModule) -> bool:
Expand All @@ -58,7 +58,7 @@ def well_formed(mod: tvm.IRModule) -> bool:
ret: bool
True if the IRModule is well formed, False if not.
"""
return _ffi_api.well_formed(mod)
return _ffi_api.well_formed(mod) # type: ignore


def get_var2val(func: Function) -> Dict[Var, Expr]:
Expand All @@ -75,7 +75,7 @@ def get_var2val(func: Function) -> Dict[Var, Expr]:
Dict[Var, Expr]
A mapping from Var to Expr.
"""
return _ffi_api.get_var2val(func)
return _ffi_api.get_var2val(func) # type: ignore


def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]:
Expand All @@ -92,12 +92,12 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]:
Dict[Var, List[Var]]
A mapping from variable definition to its uses.
"""
return _ffi_api.udchain(dfb)
return _ffi_api.udchain(dfb) # type: ignore


def name_to_binding(func: Function) -> Dict[str, List[Binding]]:
"""Return a map from variable name to its bindings."""
return _ffi_api.name_to_binding(func)
return _ffi_api.name_to_binding(func) # type: ignore


def remove_all_unused(func: Function) -> Function:
Expand All @@ -113,7 +113,7 @@ def remove_all_unused(func: Function) -> Function:
Function
The function with unused variables removed.
"""
return _ffi_api.remove_all_unused(func)
return _ffi_api.remove_all_unused(func) # type: ignore


def shape_vars(expr: Expr) -> List[tir.Var]:
Expand All @@ -133,7 +133,7 @@ def shape_vars(expr: Expr) -> List[tir.Var]:
ret: List[tir.Var]
A list of all shape variables (TIR variables) in the expression.
"""
return _ffi_api.shape_vars(expr)
return _ffi_api.shape_vars(expr) # type: ignore


def derive_func_ret_shape(args: List[Var], body: Expr) -> Expr:
Expand All @@ -156,7 +156,7 @@ def derive_func_ret_shape(args: List[Var], body: Expr) -> Expr:
ret: Expr
An expression that can serve as the return shape for the function
"""
return _ffi_api.derive_func_ret_shape(args, body)
return _ffi_api.derive_func_ret_shape(args, body) # type: ignore


def bound_vars(expr: Expr) -> List[Var]:
Expand All @@ -176,7 +176,7 @@ def bound_vars(expr: Expr) -> List[Var]:
ret: List[Var]
List of bound vars in expr, in post-DFS order
"""
return _ffi_api.bound_vars(expr)
return _ffi_api.bound_vars(expr) # type: ignore


def free_vars(expr: Expr) -> List[Var]:
Expand All @@ -196,7 +196,7 @@ def free_vars(expr: Expr) -> List[Var]:
ret: List[Var]
List of free vars in expr, in post-DFS order
"""
return _ffi_api.free_vars(expr)
return _ffi_api.free_vars(expr) # type: ignore


def all_vars(expr: Expr) -> List[Var]:
Expand All @@ -213,7 +213,7 @@ def all_vars(expr: Expr) -> List[Var]:
ret: List[Var]
List of vars in expr, in post-DFS order
"""
return _ffi_api.all_vars(expr)
return _ffi_api.all_vars(expr) # type: ignore


def all_global_vars(expr: Expr) -> List[GlobalVar]:
Expand All @@ -230,7 +230,7 @@ def all_global_vars(expr: Expr) -> List[GlobalVar]:
ret: List[GlobalVar]
List of global vars in expr, in post-DFS order
"""
return _ffi_api.all_global_vars(expr)
return _ffi_api.all_global_vars(expr) # type: ignore


def called_global_vars(expr: Expr) -> List[GlobalVar]:
Expand All @@ -248,4 +248,4 @@ def called_global_vars(expr: Expr) -> List[GlobalVar]:
List of global vars that are used recursively in expr,
in post-DFS order
"""
return _ffi_api.called_global_vars(expr)
return _ffi_api.called_global_vars(expr) # type: ignore
16 changes: 9 additions & 7 deletions python/tvm/relax/binding_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def __init__(self, dfb: DataflowBlock, root_fn: Function):
The root function of the DataflowBlock.
"""
self.func_name = root_fn.__name__ if hasattr(root_fn, "__name__") else None
self.__init_handle_by_constructor__(_ffi_api.DataflowBlockRewrite, dfb, root_fn)
self.__init_handle_by_constructor__(
_ffi_api.DataflowBlockRewrite, dfb, root_fn # type: ignore
)

def replace_all_uses(self, old_var: Var, new_var: Var) -> None:
"""
Expand All @@ -64,10 +66,10 @@ def replace_all_uses(self, old_var: Var, new_var: Var) -> None:
new_var : Var
The new variable to replace with.
"""
_ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var)
_ffi_api.dfb_rewrite_replace_all_uses(self, old_var, new_var) # type: ignore

def add_binding(self, binding: Binding) -> None:
return _ffi_api.dfb_rewrite_add_binding(self, binding)
return _ffi_api.dfb_rewrite_add_binding(self, binding) # type: ignore

def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) -> None:
"""
Expand All @@ -89,7 +91,7 @@ def add(self, expr: Expr, name: Optional[str] = None, is_dfvar: bool = False) ->
it will be Var. Being Var means the variables are output variables of the DataflowBlock.
While being DataflowVar means the variables are internal variables of the DataflowBlock.
"""
_ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar)
_ffi_api.dfb_rewrite_add(self, expr, name, is_dfvar) # type: ignore

def remove_unused(self, var: Var, allow_undef=False) -> None:
"""
Expand All @@ -106,7 +108,7 @@ def remove_unused(self, var: Var, allow_undef=False) -> None:
------
TVMError if the variable is used or undefined (allow_undef=False).
"""
_ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef)
_ffi_api.dfb_rewrite_remove_unused(self, var, allow_undef) # type: ignore

def remove_all_unused(self) -> None:
"""
Expand All @@ -116,7 +118,7 @@ def remove_all_unused(self) -> None:
-----
This could remove unused variables in other DataflowBlocks as well.
"""
_ffi_api.dfb_rewrite_remove_all_unused(self)
_ffi_api.dfb_rewrite_remove_all_unused(self) # type: ignore

def mutated_dfb(self) -> DataflowBlock:
"""
Expand Down Expand Up @@ -147,7 +149,7 @@ def mutate_irmodule(self, irmodule: tvm.IRModule) -> tvm.IRModule:
tvm.IRModule
The updated IRModule.
"""
ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule)
ret = _ffi_api.dfb_rewrite_mutate_irmodule(self, irmodule) # type: ignore
if hasattr(irmodule, "__name__"):
ret.__name__ = irmodule.__name__
return ret
40 changes: 20 additions & 20 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,19 @@ def current():
return BlockBuilder._current

def __init__(self, mod: IRModule = None):
self._blocks = []
self._blocks: List[BindingBlock] = []
# a boolean flag that tracks if emit_func_output has been called
self._is_emit_func_output_called = False
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod)
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate, mod) # type: ignore

def _begin_dataflow_block(self) -> None:
_ffi_api.BlockBuilderBeginDataflowBlock(self)
_ffi_api.BlockBuilderBeginDataflowBlock(self) # type: ignore

def _begin_binding_block(self) -> None:
_ffi_api.BlockBuilderBeginBindingBlock(self)
_ffi_api.BlockBuilderBeginBindingBlock(self) # type: ignore

def _end_block(self) -> BindingBlock:
return _ffi_api.BlockBuilderEndBlock(self)
return _ffi_api.BlockBuilderEndBlock(self) # type: ignore

def _enter_function_scope(self, name, params, attrs):
if BlockBuilder.current() is not None:
Expand Down Expand Up @@ -196,7 +196,7 @@ def _convert_te_arg(self, te_args: Any) -> typing.Tuple[Any, List[tvm.te.Tensor]
te_args_list = []

def _convert_te_arg_helper(arg):
if isinstance(arg, Expr):
if isinstance(arg, Expr): # type: ignore
arg = te_tensor(arg)
te_args_list.append(arg)
return arg
Expand Down Expand Up @@ -307,7 +307,7 @@ def emit(self, expr: Expr) -> Var:
ret : tvm.relax.Var
A newly created variable that gets bound to the input expr.
"""
return _ffi_api.BlockBuilderEmit(self, expr)
return _ffi_api.BlockBuilderEmit(self, expr) # type: ignore

def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr:
"""Generate a call node according to the te function.
Expand Down Expand Up @@ -521,7 +521,7 @@ def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
ret : tvm.relax.Var
A newly created variable that gets bound to the call code.
"""
return _ffi_api.BlockBuilderEmitMatchShape(self, value, pattern)
return _ffi_api.BlockBuilderEmitMatchShape(self, value, pattern) # type: ignore

def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
"""Emit output for the current dataflow block or function.
Expand All @@ -538,7 +538,7 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
"""
if isinstance(output, (list, tuple)):
output = Tuple(output)
return _ffi_api.BlockBuilderEmitOutput(self, output)
return _ffi_api.BlockBuilderEmitOutput(self, output) # type: ignore

def emit_func_output(
self,
Expand Down Expand Up @@ -610,7 +610,7 @@ def normalize(self, expr: Expr) -> Expr:
ret : Expr
The expr with normalized shape and type.
"""
return _ffi_api.BlockBuilderNormalize(self, expr)
return _ffi_api.BlockBuilderNormalize(self, expr) # type: ignore

def get(self) -> tvm.IRModule:
"""Return the IRModule being built.
Expand All @@ -620,7 +620,7 @@ def get(self) -> tvm.IRModule:
ret : tvm.IRModule
An IRModule with Relax and TIR functions being built.
"""
return _ffi_api.BlockBuilderGetContextIRModule(self)
return _ffi_api.BlockBuilderGetContextIRModule(self) # type: ignore

def get_unique_name(self, name_prefix: str) -> str:
"""Generate a unique name with a specified prefix.
Expand All @@ -635,7 +635,7 @@ def get_unique_name(self, name_prefix: str) -> str:
ret : str
The generated name.
"""
return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix)
return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix) # type: ignore

def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar:
"""Add a Relax function or a TIR PrimFunc to the IRModule being built.
Expand All @@ -653,7 +653,7 @@ def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar:
gvar : GlobalVar
The global var bound to the added function.
"""
return _ffi_api.BlockBuilderAddFunction(self, func, func_name)
return _ffi_api.BlockBuilderAddFunction(self, func, func_name) # type: ignore

def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None:
"""Add a Relax function or a TIR PrimFunc to the IRModule being built.
Expand All @@ -666,7 +666,7 @@ def update_func(self, gv: GlobalVar, updated_func: BaseFunc) -> None:
updated_func : BaseFunc
The updated function.
"""
return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func)
return _ffi_api.BlockBuilderUpdateFunction(self, gv, updated_func) # type: ignore

def can_prove_shape_equal(self, lhs: Expr, rhs: Expr) -> bool:
"""Check if two shape expressions can be proven equal at compile time.
Expand All @@ -684,7 +684,7 @@ def can_prove_shape_equal(self, lhs: Expr, rhs: Expr) -> bool:
ret : bool
Whether we can prove lhs shape is the same as the rhs shape.
"""
return _ffi_api.BlockBuilderCanProveShapeEqual(self, lhs, rhs)
return _ffi_api.BlockBuilderCanProveShapeEqual(self, lhs, rhs) # type: ignore

def current_block_is_dataflow(self) -> bool:
"""Check if the block being built is DataflowBlock or not.
Expand All @@ -694,7 +694,7 @@ def current_block_is_dataflow(self) -> bool:
ret : bool
A boolean that indicates if the block being built is DataflowBlock or not.
"""
return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self)
return _ffi_api.BlockBuilderCurrentBlockIsDataFlow(self) # type: ignore

def emit_var_binding(self, binding: VarBinding) -> Var:
"""Emits a variable binding, and returns the bound Var.
Expand All @@ -709,7 +709,7 @@ def emit_var_binding(self, binding: VarBinding) -> Var:
var: Var
The bound variable.
"""
return _ffi_api.BlockBuilderEmitVarBinding(self, binding)
return _ffi_api.BlockBuilderEmitVarBinding(self, binding) # type: ignore

def emit_output_var_binding(self, binding: VarBinding) -> Var:
"""Generate an output for the current dataflow block.
Expand All @@ -724,7 +724,7 @@ def emit_output_var_binding(self, binding: VarBinding) -> Var:
var: Var
The variable bound to output.
"""
return _ffi_api.BlockBuilderEmitOutputVarBinding(self, binding)
return _ffi_api.BlockBuilderEmitOutputVarBinding(self, binding) # type: ignore

def match_shape_binding(self, binding: MatchShape) -> Var:
"""Emit a MatchShape binding.
Expand All @@ -739,7 +739,7 @@ def match_shape_binding(self, binding: MatchShape) -> Var:
var: Var
The variable bound to the MatchShape.
"""
return _ffi_api.BlockBuilderEmitMatchShapeBinding(self, binding)
return _ffi_api.BlockBuilderEmitMatchShapeBinding(self, binding) # type: ignore

def lookup_binding(self, var: Var) -> Optional[Expr]:
"""Lookup a var in the binding table binding_table_.
Expand All @@ -754,4 +754,4 @@ def lookup_binding(self, var: Var) -> Optional[Expr]:
expr: Expr
The Expr bound to the input var.
"""
return _ffi_api.BlockBuilderLookupBinding(self, var)
return _ffi_api.BlockBuilderLookupBinding(self, var) # type: ignore
10 changes: 5 additions & 5 deletions python/tvm/relax/dpl/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ def __init__(self, incremental=False):
incremental : bool, optional
perform incremental matching based on the recent context, by default False
"""
self.__init_handle_by_constructor__(ffi.PatternContext, incremental)
self.__init_handle_by_constructor__(ffi.PatternContext, incremental) # type: ignore

def __enter__(self):
"""Enter the context"""
ffi.enter_context(self)
ffi.enter_context(self) # type: ignore
return self

def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context"""
ffi.exit_context(self)
ffi.exit_context(self) # type: ignore

@staticmethod
def current() -> "PatternContext":
Expand All @@ -58,7 +58,7 @@ def current() -> "PatternContext":
PatternContext
The current context
"""
return ffi.current_context()
return ffi.current_context() # type: ignore

def match_dfb(
self,
Expand All @@ -83,4 +83,4 @@ def match_dfb(
Dict[DFPattern, Var]
The mapping from DFPattern to matched expression
"""
return ffi.match_dfb(self, dfb, start_hint, must_include_hint)
return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore
Loading

0 comments on commit a443b8d

Please sign in to comment.