While reviewing the PyTorch source, I noticed that cummaxmin_backward is declared device-agnostic in native_functions.yaml:
- func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False
The implementation of cummaxmin_backward is as follows:
Tensor cummaxmin_backward(const Tensor& grad, const Tensor& input, const Tensor& indices, int64_t dim) {
  if (input.sym_numel() == 0) {
    return input;
  }
  auto result = at::zeros_symint(input.sym_sizes(), input.options());
  // for composite compliance, use out-of-place variant of
  // `scatter_add` if `indices` or `grad` is a Tensor Subclass.
  if (areAnyTensorSubclassLike({indices, grad})) {
    return result.scatter_add(dim, indices, grad);
  }
  return result.scatter_add_(dim, indices, grad);
}
The actual kernel being invoked is scatter_add_, and the flag_gems library already implements the scatter operator. This raises the question: does cummaxmin_backward still need its own implementation, or could we simply wrap the existing scatter operation to provide the in-place variant?
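Under that reading, the entire backward reduces to a zeros tensor plus a single scatter_add along dim. A minimal sketch of that idea in PyTorch, checked against the gradient that autograd produces for torch.cummax (the function here mirrors the ATen code above; it is an illustration, not the flag_gems implementation):

```python
import torch

def cummaxmin_backward(grad, inp, indices, dim):
    # Mirror of the ATen implementation: route each upstream gradient
    # back to the element that produced the running max/min at that step.
    if inp.numel() == 0:
        return inp
    result = torch.zeros_like(inp)
    return result.scatter_add_(dim, indices, grad)

# Compare against autograd on torch.cummax.
x = torch.randn(4, 5, requires_grad=True)
values, indices = torch.cummax(x, dim=1)
grad_out = torch.randn_like(values)
values.backward(grad_out)

manual = cummaxmin_backward(grad_out, x.detach(), indices, dim=1)
assert torch.allclose(manual, x.grad)
```

The same routine serves both cummax and cummin, since the forward pass already materializes the argmax/argmin indices and the backward only scatters through them.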
Description
Develop the backward function for the cummax and cummin operators.
Requirements
Interface
cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor
Function reference
https://pytorch.org/docs/stable/generated/torch.cummax.html#torch-cummax
https://pytorch.org/docs/stable/generated/torch.cummin.html#torch-cummin
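For context, torch.cummax (and likewise torch.cummin) returns both the running values and the indices of the elements that produced them; those indices are exactly what cummaxmin_backward consumes:

```python
import torch

x = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0])
values, indices = torch.cummax(x, dim=0)
print(values)   # tensor([1., 3., 3., 5., 5.])
print(indices)  # tensor([0, 1, 1, 3, 3])
```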
Implementation reference
https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/cummin.py
The operator should support all optional arguments defined in the interface.
Please provide both accuracy test and performance test code.
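One possible shape for the accuracy test: compare the candidate kernel against ATen's own device-agnostic op. The flag_gems entry-point name in the comment is hypothetical; the placeholder below uses the equivalent scatter_add_ formulation so the sketch is runnable as-is:

```python
import torch

def test_cummaxmin_backward_accuracy(shape=(8, 16), dim=1):
    x = torch.randn(shape)
    _, indices = torch.cummax(x, dim=dim)
    grad = torch.randn(shape)
    # Reference: ATen's device-agnostic implementation.
    ref = torch.ops.aten.cummaxmin_backward(grad, x, indices, dim)
    # Swap in the candidate kernel here (entry-point name is hypothetical):
    # out = flag_gems.cummaxmin_backward(grad, x, indices, dim)
    out = torch.zeros_like(x).scatter_add_(dim, indices, grad)
    torch.testing.assert_close(out, ref)

test_cummaxmin_backward_accuracy()
```

A performance test would time the same call over a sweep of shapes and dims, following the benchmark harness already used in the FlagGems repository.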
Deadline
Please submit a Pull Request within three weeks of accepting the assignment.