forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpmmReduceKernel.cpp
565 lines (492 loc) · 19.6 KB
/
SpmmReduceKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/SpmmReduceKernel.h>
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <ATen/OpMathType.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
namespace {
// Folds `val * other_data[c * K + k]` into `out_ptr[k]` for k in [0, K), using
// the reduction selected by `reduce`.
//
//   out_ptr    - destination row of K elements, read and written in place.
//   e          - index of the nonzero being processed; unused in this routine.
//   c          - row of the dense operand to read.
//   val        - scalar multiplier applied to every element of that row.
//   other_data - base pointer of the dense operand (rows of K elements).
//   K          - number of elements per row.
//
// The row is walked in three passes: four vectors per iteration, then one
// vector per iteration, then a scalar tail for whatever is left.
template <typename scalar_t, typename index_t, ReductionType reduce>
inline void _update(at::opmath_type<scalar_t>* out_ptr, int64_t e, int64_t c, const scalar_t val, const scalar_t* other_data, int64_t K) {
  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<scalar_t>;
  using aVec = VecType<scalar_t>;

  constexpr int64_t kLanes = Vec::size();
  constexpr int64_t kUnrolledSpan = kLanes * 4;

  // End offsets of the 4x-unrolled pass and of the single-vector pass.
  const int64_t unrolled_end = K - (K % kUnrolledSpan);
  const int64_t vector_end = K - (K % kLanes);

  const scalar_t* src = other_data + c * K;
  aVec scale = aVec(static_cast<opmath_t>(val));

  int64_t pos = 0;

  // Pass 1: four vectors at a time; all loads happen before any store.
  for (; pos < unrolled_end; pos += kUnrolledSpan) {
    aVec acc0 = aVec::loadu(out_ptr + pos);
    aVec acc1 = aVec::loadu(out_ptr + pos + kLanes);
    aVec acc2 = aVec::loadu(out_ptr + pos + kLanes * 2);
    aVec acc3 = aVec::loadu(out_ptr + pos + kLanes * 3);

    acc0 = update<aVec, reduce>(acc0, aVec::loadu(src + pos) * scale);
    acc1 = update<aVec, reduce>(acc1, aVec::loadu(src + pos + kLanes) * scale);
    acc2 = update<aVec, reduce>(acc2, aVec::loadu(src + pos + kLanes * 2) * scale);
    acc3 = update<aVec, reduce>(acc3, aVec::loadu(src + pos + kLanes * 3) * scale);

    acc0.store(out_ptr + pos);
    acc1.store(out_ptr + pos + kLanes);
    acc2.store(out_ptr + pos + kLanes * 2);
    acc3.store(out_ptr + pos + kLanes * 3);
  }

  // Pass 2: remaining whole vectors, one per iteration.
  for (; pos < vector_end; pos += kLanes) {
    aVec acc = aVec::loadu(out_ptr + pos);
    acc = update<aVec, reduce>(acc, aVec::loadu(src + pos) * scale);
    acc.store(out_ptr + pos);
  }

  // Pass 3: scalar tail.
  for (; pos < K; pos++) {
    opmath_t acc = opmath_t(out_ptr[pos]);
    acc = update<opmath_t, reduce>(acc, opmath_t(src[pos]) * opmath_t(val));
    out_ptr[pos] = acc;
  }
}
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_kernel_impl(
const Tensor& out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
const Tensor& other_) {
int64_t nnz = values.numel();
if (nnz == 0) {
return;
}
auto other = other_.contiguous();
// access `crow_indices`, `col_indices` and `values` via TensorAccessor
scalar_t* out_data = out.data_ptr<scalar_t>();
auto csr_data = crow_indices.accessor<const index_t, 1>();
auto col_data = col_indices.accessor<const index_t, 1>();
auto val_data = values.accessor<const scalar_t, 1>();
const scalar_t* other_data = other.const_data_ptr<scalar_t>();
int64_t M = crow_indices.numel() - 1;
int64_t K = other.size(-1);
int num_threads = at::get_num_threads();
using opmath_t = at::opmath_type<scalar_t>;
Tensor buffer;
opmath_t* buffer_data = nullptr;
static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
if constexpr (need_acc) {
auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
buffer_data = buffer.data_ptr<opmath_t>();
}
utils::parallel_sparse_csr(csr_data, M, nnz, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
TORCH_CHECK(tid < num_threads,
"expect thread id smaller than ", num_threads, ", got thread id ", tid);
opmath_t* buffer_ptr = nullptr;
int64_t row_start = 0, row_end = 0;
for (const auto m : c10::irange(begin, end)) {
row_start = csr_data[m];
row_end = csr_data[m + 1];
scalar_t* out_ptr = out_data + m * K;
if constexpr (need_acc) {
buffer_ptr = buffer_data + tid * K;
} else {
buffer_ptr = reinterpret_cast<opmath_t*>(out_ptr);
}
// step 1: reinit the output row for reduce type 'amax' and 'amin'
int64_t count = row_end - row_start;
if (count != 0) {
_init<scalar_t, reduce>(out_ptr, buffer_ptr, K, /*include_self*/false);
}
// step 2: reduce, do blocking on rowwise to reduce write memory bandwidth
constexpr int64_t CHUNK_SIZE = 16;
for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);
for (const auto e : c10::irange(e0, e1)) {
int64_t c = col_data[e];
scalar_t val = val_data[e];
_update<scalar_t, index_t, reduce>(buffer_ptr, e, c, val, other_data, K);
}
}
if constexpr (need_acc) {
if (count != 0) {
vec::convert(buffer_ptr, out_ptr, K);
}
}
// step 3: finalize
write<scalar_t, reduce>(out_ptr, count, K);
}
});
}
// Updates a running extremum together with the index it came from; used for
// `amin` and `amax`. `*val` / `*arg` are overwritten with `new_val` / `new_arg`
// when `new_val` is strictly smaller (MIN) or strictly larger (MAX) than
// `*val`, or when `new_val` is NaN.
// This is kept scalar: it is a little troublesome to vectorize since `scalar_t`
// and `index_t` might have different vector lengths, for example each vector
// holds 8 floats and 4 int64_t.
template <typename scalar_t, typename index_t, ReductionType reduce>
inline void update_with_index(scalar_t *val, scalar_t new_val, index_t *arg, index_t new_arg) {
  const bool wins_as_min = (reduce == ReductionType::MIN) && (new_val < *val);
  const bool wins_as_max = (reduce == ReductionType::MAX) && (new_val > *val);
  const bool should_replace =
      wins_as_min || wins_as_max || at::_isnan<scalar_t>(new_val);

  if (!should_replace) {
    return;
  }
  *val = new_val;
  *arg = new_arg;
}
// Forward CSR sparse x dense reduction for MAX / MIN that also records, per
// output element, the index of the nonzero whose product was kept.
//
//   out          - dense result, written in place.
//   arg_out      - per-element nonzero index, written in place for rows that
//                  have at least one nonzero; other rows are left untouched.
//   crow_indices - CSR row pointers; the number of rows is numel() - 1.
//   col_indices  - column index of each nonzero.
//   values       - value of each nonzero; nothing is done when it is empty.
//   other_       - dense operand; made contiguous before use.
//
// When `scalar_t` is a reduced floating point type, the running extremum lives
// in a per-thread scratch row of the accumulate type and is converted back.
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_arg_kernel_impl(
    const Tensor& out,
    const Tensor& arg_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other_) {
  using opmath_t = at::opmath_type<scalar_t>;
  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;

  TORCH_CHECK(reduce == ReductionType::MAX || reduce == ReductionType::MIN);

  const int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto other = other_.contiguous();

  scalar_t* out_data = out.data_ptr<scalar_t>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();
  auto csr_data = crow_indices.accessor<const index_t, 1>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto val_data = values.accessor<const scalar_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();

  const int64_t M = crow_indices.numel() - 1;
  const int64_t K = other.size(-1);

  const int num_threads = at::get_num_threads();

  // One scratch row of K accumulators per thread, only for reduced precision.
  Tensor buffer;
  opmath_t* buffer_data = nullptr;
  if constexpr (need_acc) {
    auto acc_type = at::toAccumulateType(out.scalar_type(), /*is_cuda=*/true);
    buffer = at::zeros({num_threads, K}, out.options().dtype(acc_type));
    buffer_data = buffer.data_ptr<opmath_t>();
  }

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    const int tid = at::get_thread_num();
    TORCH_CHECK(tid < num_threads,
                "expect thread id smaller than ", num_threads, ", got thread id ", tid);

    for (const auto m : c10::irange(begin, end)) {
      const int64_t first = csr_data[m];
      const int64_t last = csr_data[m + 1];

      // Rows without nonzeros keep whatever `out` / `arg_out` already hold.
      if (first == last) {
        continue;
      }

      scalar_t* out_row = out_data + m * K;
      index_t* arg_row = arg_out_data + m * K;

      opmath_t* acc_row = nullptr;
      if constexpr (need_acc) {
        acc_row = buffer_data + tid * K;
      } else {
        acc_row = reinterpret_cast<opmath_t*>(out_row);
      }

      _init<scalar_t, reduce>(out_row, acc_row, K, /*include_self*/false);

      for (const auto e : c10::irange(first, last)) {
        const int64_t c = col_data[e];
        const opmath_t val = opmath_t(val_data[e]);
        const scalar_t* other_row = other_data + c * K;
        for (const auto k : c10::irange(K)) {
          update_with_index<opmath_t, index_t, reduce>(
              &acc_row[k], opmath_t(val * other_row[k]), &arg_row[k], index_t(e));
        }
      }

      // Move the accumulated row back into the output dtype.
      if constexpr (need_acc) {
        vec::convert(acc_row, out_row, K);
      }
    }
  });
}
// Backward w.r.t. the sparse values. For nonzero i at (row, col) the gradient
// is the dot product of `other[col, :]` and `grad_out[row, :]`; for MEAN it is
// additionally divided by the number of nonzeros in that row.
//
//   grad_self    - sparse gradient; its values() are overwritten in place.
//   grad_out_    - dense upstream gradient; made contiguous before use.
//   crow_indices - CSR row pointers, read only for the MEAN row counts.
//   col_indices  - column index of each nonzero.
//   other_       - dense operand; made contiguous before use.
//   row_indices  - row index of each nonzero.
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_backward_input_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& row_indices) {
  using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;

  const int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();

  auto values = grad_self.values();
  auto grad_values_data = values.accessor<scalar_t, 1>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto crow_data = crow_indices.accessor<const index_t, 1>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();
  auto row_data = row_indices.accessor<const index_t, 1>();

  const int64_t K = grad_out.size(1);

  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      const index_t row = row_data[i];
      const index_t col = col_data[i];

      // Dot product over the K features of the two dense rows.
      scalar_t dot = vec::map2_reduce_all<scalar_t>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          other_data + col * K,
          grad_out_data + row * K,
          K);

      if (reduce == ReductionType::MEAN) {
        // Number of nonzeros stored in this row.
        const index_t row_count = crow_data[row + 1] - crow_data[row];
        dot /= row_count;
      }

      grad_values_data[i] = dot;
    }
  });
}
// Backward w.r.t. the sparse values for reduce type 'amax' or 'amin'.
//
// `arg_out_` holds, for every output element (m, k), the index of a nonzero;
// entries equal to `nnz` are treated as "no nonzero selected" and contribute
// nothing. For every other entry `ind`, the product
// `other[col_indices[ind], k] * grad_out[m, k]` is added to
// `grad_self.values()[ind]`.
//
//   grad_self   - sparse gradient; its values() are accumulated into in place
//                 (existing contents are kept and added to).
//   grad_out_   - dense upstream gradient of shape {M, K}; made contiguous.
//   col_indices - column index of each nonzero.
//   other_      - dense operand; made contiguous before use.
//   arg_out_    - selected nonzero index per output element; made contiguous.
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_input_arg_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& arg_out_) {
  int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();
  auto arg_out = arg_out_.contiguous();

  auto grad_values = grad_self.values();
  auto grad_values_data = grad_values.accessor<scalar_t, 1>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();
  // `arg_out` is only ever read in this function, so take the const pointer,
  // matching spmm_reduce_backward_other_arg_kernel_impl.
  const index_t* arg_out_data = arg_out.const_data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  // Dense staging buffer: per-element contributions are computed in parallel
  // first, then scattered serially below.
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      const index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        const index_t ind = arg_out_ptr[k];
        if (ind == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          // collect weight at max/min indices
          index_t col = col_data[ind];
          grad_ptr[k] = other_data[col * K + k] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add, consider to parallel this with atomic
  // (several output elements may point at the same nonzero, hence serial).
  for (const auto i : c10::irange(M * K)) {
    index_t ind = arg_out_data[i];
    if (ind != index_t(nnz)) {
      grad_values_data[ind] += grad_data[i];
    }
  }
}
template <typename scalar_t, typename index_t>
void spmm_reduce_normalize_values_kernel_impl(
const Tensor& normalized_values,
const Tensor& values,
const Tensor& crow_indices,
const Tensor& row_indices) {
int64_t nnz = values.numel();
if (nnz == 0) {
return;
}
auto normalized_values_data = normalized_values.accessor<scalar_t, 1>();
auto values_data = values.accessor<scalar_t, 1>();
auto crow_data = crow_indices.accessor<index_t, 1>();
auto row_data = row_indices.accessor<index_t, 1>();
at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
index_t row = row_data[i];
index_t row_start = crow_data[row], row_end = crow_data[row + 1];
// Note that when the row index row is listed in row_indices,
// then crow_indices[row+1] > crow_indices[row] holds
normalized_values_data[i] = values_data[i] / (row_end - row_start);
}
});
}
// Backward w.r.t. the dense operand for the arg-tracking reductions.
//
// `arg_out_` holds, for every output element (m, k), the index of a nonzero;
// entries equal to `nnz` are treated as "no nonzero selected" and contribute
// nothing. For every other entry `ind`, `values[ind] * grad_out[m, k]` is added
// to `grad_other[col_indices[ind], k]`.
//
//   grad_other  - dense gradient, accumulated into in place.
//   grad_out_   - dense upstream gradient of shape {M, K}; made contiguous.
//   col_indices - column index of each nonzero.
//   values      - value of each nonzero; nothing is done when it is empty.
//   arg_out_    - selected nonzero index per output element; made contiguous.
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_other_arg_kernel_impl(
    const Tensor& grad_other,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out_) {
  const int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto arg_out = arg_out_.contiguous();

  scalar_t* grad_other_data = grad_other.data_ptr<scalar_t>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto values_data = values.accessor<const scalar_t, 1>();
  const index_t* arg_out_data = arg_out.const_data_ptr<index_t>();

  const int64_t M = grad_out.size(0);
  const int64_t K = grad_out.size(1);

  // Dense staging buffer for the per-element contributions.
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  // Phase 1 (parallel over rows): compute each element's contribution.
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* upstream_row = grad_out_data + m * K;
      scalar_t* staged_row = grad_data + m * K;
      const index_t* arg_row = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_row[k] != index_t(nnz)) {
          staged_row[k] = values_data[arg_row[k]] * upstream_row[k];
        } else {
          staged_row[k] = scalar_t(0);
        }
      }
    }
  });

  // Phase 2 (serial): scatter_add, consider to parallel this with atomic
  for (const auto m : c10::irange(M)) {
    const index_t* arg_row = arg_out_data + m * K;
    const scalar_t* staged_row = grad_data + m * K;

    for (const auto k : c10::irange(K)) {
      const index_t ind = arg_row[k];
      if (ind == index_t(nnz)) {
        continue;
      }
      const index_t col = col_data[ind];
      grad_other_data[col * K + k] += staged_row[k];
    }
  }
}
// Dispatching entry point for the forward reduction kernel.
// Resolves `scalar_t` from `values.scalar_type()` (floating types plus
// BFloat16 and Half), `index_t` from `col_indices.scalar_type()` and the
// compile-time `reduce` from `reduce_op`, then forwards all tensors unchanged
// to spmm_reduce_kernel_impl.
void spmm_reduce_kernel(
const Tensor& out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
out, crow_indices, col_indices, values, other);
});
});
});
}
// Dispatching entry point for the arg-tracking forward reduction kernel.
// Resolves `scalar_t` from `values.scalar_type()` (floating types plus
// BFloat16 and Half), `index_t` from `col_indices.scalar_type()` and the
// compile-time `reduce` from `reduce_op`, then forwards all tensors unchanged
// to spmm_reduce_arg_kernel_impl.
//
// The dispatch labels name this kernel (rather than the non-arg forward
// kernel) so that diagnostics produced by the dispatch macros point here.
void spmm_reduce_arg_kernel(
    const Tensor& out,
    const Tensor& arg_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other,
    ReductionType reduce_op) {
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_arg_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_arg_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
            out, arg_out, crow_indices, col_indices, values, other);
      });
    });
  });
}
// Dispatching entry point for the backward pass w.r.t. the sparse values.
// Only SUM and MEAN are accepted; any other `reduce_op` fails the TORCH_CHECK.
// Resolves `scalar_t` from `other.scalar_type()`, `index_t` from
// `col_indices.scalar_type()` and `reduce` from `reduce_op`, then forwards to
// spmm_reduce_backward_input_kernel_impl.
void spmm_reduce_backward_input_kernel(
const Tensor& grad_self,
const Tensor& grad_out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& other,
const Tensor& row_indices,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
grad_self, grad_out, crow_indices, col_indices, other, row_indices);
});
});
});
}
// Dispatching entry point for the backward pass w.r.t. the sparse values when
// the forward reduction tracked argument indices.
// Only MAX and MIN are accepted; any other `reduce_op` fails the TORCH_CHECK.
// `reduce_op` is used solely for that check: the implementation is templated
// on `scalar_t` (from `other`) and `index_t` (from `col_indices`) only.
void spmm_reduce_backward_input_arg_kernel(
const Tensor& grad_self,
const Tensor& grad_out,
const Tensor& col_indices,
const Tensor& other,
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
grad_self, grad_out, col_indices, other, arg_out);
});
});
}
// Dispatching entry point for the per-row value normalization.
// Resolves `scalar_t` from `values.scalar_type()` and `index_t` from
// `crow_indices.scalar_type()`, then forwards to
// spmm_reduce_normalize_values_kernel_impl, which writes `normalized_values`.
void spmm_reduce_normalize_values_kernel(
const Tensor& normalized_values,
const Tensor& values,
const Tensor& crow_indices,
const Tensor& row_indices) {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
normalized_values, values, crow_indices, row_indices);
});
});
}
void spmm_reduce_backward_other_kernel(
const Tensor& grad_other,
const Tensor& grad_out,
const Tensor& crow_indices,
const Tensor& values,
const Tensor& row_indices,
const Tensor& ccol_indices,
const Tensor& csr2csc,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
// need to permute row_indices to CSC order
auto row = row_indices.index_select(0, csr2csc);
Tensor val;
if (reduce_op == ReductionType::MEAN) {
// for reduce type "mean", need to normalize the values
// with rowcount for each of the nonzero element.
Tensor normalized_values = at::empty(values.sizes(), values.options());
spmm_reduce_normalize_values_kernel(normalized_values, values, crow_indices, row_indices);
val = normalized_values.index_select(0, csr2csc);
} else {
val = values.index_select(0, csr2csc);
}
spmm_reduce_kernel(grad_other, ccol_indices, row, val, grad_out, ReductionType::SUM);
}
// Dispatching entry point for the backward pass w.r.t. the dense operand when
// the forward reduction tracked argument indices.
// Only MAX and MIN are accepted; any other `reduce_op` fails the TORCH_CHECK.
// `reduce_op` is used solely for that check: the implementation is templated
// on `scalar_t` (from `values`) and `index_t` (from `col_indices`) only.
void spmm_reduce_backward_other_arg_kernel(
const Tensor& grad_other,
const Tensor& grad_out,
const Tensor& col_indices,
const Tensor& values,
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
grad_other, grad_out, col_indices, values, arg_out);
});
});
}
} // anonymous namespace
// Bind each file-local entry point defined in the anonymous namespace above to
// its correspondingly named dispatch stub.
REGISTER_DISPATCH(spmm_reduce_stub, &spmm_reduce_kernel)
REGISTER_DISPATCH(spmm_reduce_arg_stub, &spmm_reduce_arg_kernel)
REGISTER_DISPATCH(spmm_reduce_backward_input_stub, &spmm_reduce_backward_input_kernel)
REGISTER_DISPATCH(spmm_reduce_backward_input_arg_stub, &spmm_reduce_backward_input_arg_kernel)
REGISTER_DISPATCH(spmm_reduce_backward_other_stub, &spmm_reduce_backward_other_kernel)
REGISTER_DISPATCH(spmm_reduce_backward_other_arg_stub, &spmm_reduce_backward_other_arg_kernel)
} // at::native