forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathVariableFallbackKernel.cpp
72 lines (57 loc) · 2.12 KB
/
VariableFallbackKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <torch/library.h>
/*
* This file implements a variable fallback kernel for custom operators.
* Since tensors always have an Autograd dispatch key set, but custom operators
* usually don't have a kernel registered for Autograd, the dispatcher
* will call into this fallback kernel instead.
* Note that this is not a correct autograd implementation. It just
* falls through to the custom operator implementation.
* If you want a custom operator to work with autograd, you need to use
* autograd::Function so that the custom operator implementation knows how to
* do autograd.
* Note also that ops from native_functions.yaml register their own variable
* kernels, so this fallback is never called for them.
*/
// TODO This whole file should be deleted and replaced with the mechanism
// described in https://github.com/pytorch/pytorch/issues/29548
using c10::OperatorHandle;
using c10::Stack;
using c10::DispatchKey;
using c10::DispatchKeySet;
using c10::Dispatcher;
using c10::KernelFunction;
namespace {
// Register fallthrough for Autograd backends dispatch keys
// NB: But not the private use ones; maybe the extension wants
// to override it themselves!
TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradCPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradXPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradCUDA, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradLazy, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradMPS, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
// see Note [ADInplaceOrView key]
TORCH_LIBRARY_IMPL(_, ADInplaceOrView, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
}