forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTorchDispatchModeTLS.cpp
58 lines (48 loc) · 1.59 KB
/
TorchDispatchModeTLS.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <ATen/core/TorchDispatchModeTLS.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/DispatchKeySet.h>
namespace at { namespace impl {
// Per-thread slot holding the currently installed torch dispatch mode object.
// An empty pointer means no mode is set on this thread. It is written by
// TorchDispatchModeTLS::set_state / reset_state and read through get_state.
thread_local std::shared_ptr<SafePyObject> torchDispatchModeState;
void TorchDispatchModeTLS::set_state(std::shared_ptr<SafePyObject> state) {
if (state) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, true);
} else {
TorchDispatchModeTLS::reset_state();
}
torchDispatchModeState = std::move(state);
}
// Returns a reference to the calling thread's stored dispatch mode pointer.
// The pointer is empty when no mode has been set on this thread.
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_state() {
  const std::shared_ptr<SafePyObject>& current_mode = torchDispatchModeState;
  return current_mode;
}
// Clears the calling thread's dispatch mode: drops the stored mode object and
// then un-includes the Python and PythonTLSSnapshot dispatch keys in TLS.
void TorchDispatchModeTLS::reset_state() {
  // Release this thread's reference to the mode object first.
  torchDispatchModeState = nullptr;

  constexpr bool kIncluded = false;
  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, kIncluded);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonTLSSnapshot, kIncluded);
}
// Reports whether the calling thread currently has a torch dispatch mode set,
// i.e. whether TorchDispatchModeTLS::get_state() holds a non-empty pointer.
bool dispatch_mode_enabled() {
  const auto& active_mode = at::impl::TorchDispatchModeTLS::get_state();
  return active_mode != nullptr;
}
// Returns true when the tensor's dispatch key set contains at least one of
// DispatchKey::Python or DispatchKey::PythonTLSSnapshot.
bool tensor_has_dispatch(const at::Tensor& t) {
  const DispatchKeySet python_keys(
      {DispatchKey::Python, DispatchKey::PythonTLSSnapshot});
  const auto tensor_keys = t.key_set();
  return tensor_keys.has_any(python_keys);
}
// Returns true if any tensor in `li` satisfies tensor_has_dispatch().
// Stops at the first match; an empty list yields false.
bool tensorlist_has_dispatch(const at::TensorList& li) {
  for (auto it = li.begin(); it != li.end(); ++it) {
    if (tensor_has_dispatch(*it)) {
      return true;
    }
  }
  return false;
}
// Returns true if any present (non-empty optional) tensor in `li` satisfies
// tensor_has_dispatch(). Empty optionals are skipped; stops at the first match.
bool tensorlist_has_dispatch(const c10::List<c10::optional<at::Tensor>>& li) {
  for (const auto idx : c10::irange(li.size())) {
    const c10::optional<at::Tensor> maybe_tensor = li.get(idx);
    if (!maybe_tensor.has_value()) {
      continue;
    }
    if (tensor_has_dispatch(*maybe_tensor)) {
      return true;
    }
  }
  return false;
}
} // namespace impl
} // namespace at