diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index c40bb36..f447bbf 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -19,6 +19,7 @@ _implicit_cvt, RESERVED_KWS, interpreter_builder, + InterpretedFunction, ) from triton.runtime.interpreter import _patch_lang as triton_patch_lang from triton.runtime import JITFunction @@ -342,6 +343,8 @@ def wrapper(input, axis=None, keep_dims=False): def patch(): old_grid_executor_call = GridExecutor.__call__ old_jit_function_call = JITFunction.__call__ + # XXX(Keren): Temporarily disable rewriting of AST + old_rewrite_ast = InterpretedFunction._rewrite_ast old_create_make_range = interpreter_builder.create_make_range old_create_masked_load = interpreter_builder.create_masked_load old_create_expand_dims = interpreter_builder.create_expand_dims @@ -350,6 +353,7 @@ def patch(): old_create_masked_store = interpreter_builder.create_masked_store GridExecutor.__call__ = _grid_executor_call JITFunction.__call__ = _jit_function_call + InterpretedFunction._rewrite_ast = lambda self: self.fn interpreter_builder.create_make_range = _create_make_range( interpreter_builder.create_make_range ) @@ -369,6 +373,7 @@ def patch(): finally: GridExecutor.__call__ = old_grid_executor_call JITFunction.__call__ = old_jit_function_call + InterpretedFunction._rewrite_ast = old_rewrite_ast interpreter_builder.create_make_range = old_create_make_range interpreter_builder.create_masked_load = old_create_masked_load interpreter_builder.create_expand_dims = old_create_expand_dims