Autograd assertion: stmt->ret_type == stmt->v->ret_type #8444

Closed · oliver-batchelor opened this issue Dec 22, 2023 · 1 comment · Fixed by #8452

Comments

oliver-batchelor (Contributor) commented Dec 22, 2023

The code below triggers the assertion when requires_grad is set on dirs (but not when it is set on params). I have reduced the problem as far as I could.

Internal error occurred. Check out this page for possible solutions:
https://docs.taichi-lang.org/docs/install
Traceback (most recent call last):
  File "/home/oliver/sync/taichi-splatting/taichi_splatting/asserts.py", line 52, in <module>
    evaluate_sh_kernel.grad(params, dirs, out)
  File "/home/oliver/mambaforge/envs/torch2/lib/python3.10/site-packages/taichi/lang/kernel_impl.py", line 1035, in __call__
    return self.launch_kernel(kernel_cpp, *args)
  File "/home/oliver/mambaforge/envs/torch2/lib/python3.10/site-packages/taichi/lang/kernel_impl.py", line 966, in launch_kernel
    raise e from None
  File "/home/oliver/mambaforge/envs/torch2/lib/python3.10/site-packages/taichi/lang/kernel_impl.py", line 959, in launch_kernel
    compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
RuntimeError: [type_check.cpp:visit@547] Assertion failure: stmt->ret_type == stmt->v->ret_type
import taichi as ti
from taichi.math import vec3, vec4
import torch

dim = 3
out_vec = ti.types.vector(dim, ti.f32)


@ti.func
def rsh_cart_1(xyz: vec3) -> vec4:
    x, y, z = xyz
    return vec4(
        0.282094791773878,
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
    )


@ti.kernel
def evaluate_sh_kernel(params: ti.types.ndarray(vec4, ndim=2),
                       dirs: ti.types.ndarray(vec3, ndim=1),
                       out: ti.types.ndarray(out_vec, ndim=1)):
    for i in range(params.shape[0]):
        coeffs = rsh_cart_1(dirs[i])

        for j in range(dim):
            params_j = params[i, j]
            out[i][j] = coeffs.dot(params_j)


if __name__ == '__main__':
    ti.init(debug=True)
    device = 'cpu'

    params = torch.rand(100, dim, 4, device=device, dtype=torch.float32)
    dirs = torch.randn(100, 3, device=device, dtype=torch.float32)
    dirs = torch.nn.functional.normalize(dirs, dim=1)

    out = torch.zeros(dirs.shape[0], dim, dtype=torch.float32, device=params.device)

    dirs.requires_grad_(True)    # asserts
    # params.requires_grad_(True)  # OK

    evaluate_sh_kernel(params, dirs, out)
    out.grad = torch.ones_like(out).contiguous()
    evaluate_sh_kernel.grad(params, dirs, out)

oliver-batchelor (Contributor, Author) commented:

Changing

for j in range(dim):

to

for j in ti.static(range(dim)):

makes the assertion go away.
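
For reference, a minimal sketch (not part of the original report) of the workaround applied to the kernel above; ti.static unrolls the inner loop at compile time, so j is a compile-time constant when indexing into the output vector:

@ti.kernel
def evaluate_sh_kernel(params: ti.types.ndarray(vec4, ndim=2),
                       dirs: ti.types.ndarray(vec3, ndim=1),
                       out: ti.types.ndarray(out_vec, ndim=1)):
    for i in range(params.shape[0]):
        coeffs = rsh_cart_1(dirs[i])

        # ti.static unrolls this loop at compile time; with a constant
        # index j the assertion no longer fires.
        for j in ti.static(range(dim)):
            out[i][j] = coeffs.dot(params[i, j])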

bobcao3 pushed a commit that referenced this issue Dec 26, 2023
Issue: fixes #8444

The return type of cmp statements on tensors should be a tensor of u1
rather than a tensor of i32.

Sometimes the CFG detects an AdStackLoadTopStmt and an
AdStackLoadTopAdjStmt as loading the same address. I don't know whether
this should happen, but the error stops being raised if I don't
eliminate the latter statement.
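
An illustrative sketch (names hypothetical, not taken from the PR) of where a tensor-valued cmp statement arises: an elementwise comparison of two Taichi vectors lowers to one cmp per lane, and per this fix the result's element type is u1 (boolean) rather than i32:

import taichi as ti
from taichi.math import vec3

ti.init()

@ti.kernel
def count_less(ax: ti.f32) -> ti.i32:
    a = vec3(ax, 2.0, 3.0)
    b = vec3(2.0, 2.0, 2.0)
    # The elementwise comparison lowers to one cmp statement per lane;
    # per the commit message, the result should be a tensor of u1
    # rather than a tensor of i32.
    mask = a < b
    n = 0
    for k in ti.static(range(3)):
        if mask[k]:
            n += 1
    return n

print(count_less(1.0))  # 1: only the first lane satisfies a < b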