forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpython_mode.cpp
27 lines (23 loc) · 956 Bytes
/
python_mode.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <torch/csrc/autograd/python_mode.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/autograd/python_variable.h>
#include <ATen/core/PythonModeTLS.h>
#include <c10/core/TensorImpl.h>
namespace torch { namespace autograd {
// Activates Python mode for `type`.
//
// If PythonModeTLS already holds a state, the TORCH_CHECK below fails, because
// nested Python modes are not supported. Otherwise `type` is wrapped, together
// with getPyInterpreter(), in a c10::TorchDispatchTypeObject and handed to
// PythonModeTLS::set_state().
void PythonMode::enter(PyObject* type) {
  const bool mode_already_set =
      static_cast<bool>(at::impl::PythonModeTLS::get_state());
  if (mode_already_set) {
    TORCH_CHECK(
        false,
        "python mode has already been set. We do not yet support nested python ",
        "mode. Please file us an issue and reset it before setting it again.");
  }

  // The wrapper steals a reference to `type`, so take one on its behalf first.
  // See NOTE [What is TorchDispatchTypeObject?]
  Py_INCREF(type);
  auto dispatch_type = std::make_shared<c10::TorchDispatchTypeObject>(
      type, getPyInterpreter());
  at::impl::PythonModeTLS::set_state(dispatch_type);
}
// Leaves Python mode by calling PythonModeTLS::reset_state().
//
// Must be paired with a preceding enter(): if PythonModeTLS::get_state()
// reports no state, the internal assert below fires.
void PythonMode::exit() {
// Exiting without a state set indicates a caller bug, hence an internal assert.
TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
at::impl::PythonModeTLS::reset_state();
}
}}