diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index cd990d0eea..283e903405 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -47,6 +47,20 @@ def _get_cutlass_path(): raise AssertionError(f"The CUTLASS root directory not found in: {invalid_paths}") +def ddd(): + print + + +class SubTreeExtractor(nn.Module): + def __init__(self): + super(SubTreeExtractor, self).__init__() + + def forward(self, x, y): + return x + y + +torch.onnx.export(SubTreeExtractor(), (torch.Tensor(1,2,3,4), 1), "1.onnx") + + def _get_cutlass_compile_options(sm, threads, use_fast_math=False): cutlass_root = _get_cutlass_path() cutlass_include = os.path.join(cutlass_root, "include")