
Code Contribution: [Hard] [Operator Development] cummaxmin_backward #397

Open · StrongSpoon opened this issue Jan 3, 2025 · 2 comments

@StrongSpoon (Collaborator)
Description 任务介绍

Develop the backward function for the cummax and cummin operators.
开发cummax和cummin算子的反向功能。

Requirements 任务要求

Interface 接口
cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor
Function reference 功能参考
https://pytorch.org/docs/stable/generated/torch.cummax.html#torch-cummax
https://pytorch.org/docs/stable/generated/torch.cummin.html#torch-cummin
Implementation reference 实现参考
https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/cummin.py

The operator should support all optional arguments defined in the interface.
算子应支持接口中定义的所有参数选项。
Please provide both accuracy test and performance test code.
请同时提供实现正确性测试与性能测试代码。
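
Since both accuracy and performance tests are requested, a rough harness sketch follows (a sketch only: the `impl` argument stands in for whatever gems-side entry point the eventual PR exposes, and the test shape, iteration count, and wall-clock timing are assumptions of mine; FlagGems' existing benchmark utilities would likely be preferred in the actual PR):

import time
import torch

def check_cummaxmin_backward(impl, shape=(64, 1024), dim=1, dtype=torch.float32, device="cuda"):
    # Accuracy: compare `impl` against the gradient PyTorch's autograd produces
    # for torch.cummax (which routes through aten::cummaxmin_backward).
    inp = torch.randn(shape, dtype=dtype, device=device, requires_grad=True)
    out, indices = torch.cummax(inp, dim=dim)
    grad = torch.randn_like(out)
    out.backward(grad)
    ref = inp.grad
    res = impl(grad, inp.detach(), indices, dim)
    torch.testing.assert_close(res, ref)

    # Performance: crude wall-clock timing as a placeholder for a proper benchmark.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        impl(grad, inp.detach(), indices, dim)
    torch.cuda.synchronize()
    print(f"avg latency: {(time.perf_counter() - start) / 100 * 1e6:.1f} us")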

DDL 提交时间

Please submit a Pull Request within three weeks of accepting the assignment.
请于接取任务后三周内提交PR。

StrongSpoon converted this from a draft issue · Jan 3, 2025
@2niuhe (Contributor) commented Jan 13, 2025

I'd like to claim this task.

Tango2018cc moved this from Todo to In Progress in Triton China Community · Jan 13, 2025
@2niuhe (Contributor) commented Jan 13, 2025

While reviewing the PyTorch code, I noticed that cummaxmin_backward is declared device-agnostic in native_functions.yaml:

- func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False

The implementation of cummaxmin_backward is as follows:

Tensor cummaxmin_backward(const Tensor& grad, const Tensor& input, const Tensor& indices, int64_t dim) {
  if (input.sym_numel() == 0) {
    return input;
  }
  auto result = at::zeros_symint(input.sym_sizes(), input.options());

  // for composite compliance, use out-of-place variant of
  // `scatter_add` if `indices` or `grad` is a Tensor Subclass.
  if (areAnyTensorSubclassLike({indices, grad})) {
    return result.scatter_add(dim, indices, grad);
  }
  return result.scatter_add_(dim, indices, grad);
}

The actual kernel being called is scatter_add_, and I noticed that the flag_gems library already implements the scatter operator. This raises the question: is there still a need to implement cummaxmin_backward, or could we simply wrap the scatter operation to provide an in-place version?
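
For concreteness, the wrap-only option would come down to something like the sketch below (plain ATen calls only; whether scatter_add_ here actually dispatches to FlagGems' scatter kernel depends on how the library registers it, which I haven't verified):

import torch

def cummaxmin_backward(grad, inp, indices, dim):
    # Same structure as the ATen implementation above: allocate zeros, then
    # scatter the incoming gradient back to the argmax/argmin positions along `dim`.
    if inp.numel() == 0:
        return inp
    result = torch.zeros(inp.shape, dtype=inp.dtype, device=inp.device)
    return result.scatter_add_(dim, indices, grad)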

Looking forward to your thoughts!
