forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpython_nested_functions_manual.cpp
44 lines (37 loc) · 1.28 KB
/
python_nested_functions_manual.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <torch/csrc/utils/nested.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/torch.h>
namespace torch {
namespace autograd {

// Python entry point for `torch.nested.nested_tensor`, registered under the
// name "nested_tensor" in `nested_functions_manual` below.
//
// Accepted Python signature (from the PythonArgParser spec string):
//   nested_tensor(data, *, dtype=None, device=None, pin_memory=False,
//                 requires_grad=False)
//
// The parsed arguments are forwarded to torch::utils::nested_tensor_ctor
// together with the default dispatch key and default scalar type reported by
// torch::tensors, and the resulting tensor is wrapped into a Python object
// with THPVariable_Wrap.
//
// NOTE(review): the HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS pair presumably
// translates C++ exceptions thrown in between (including argument-parsing
// failures) into a Python exception plus a nullptr return — confirm against
// the macro definitions.
static PyObject* THPVariable_nested_tensor(
    PyObject* /*self*/,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // Function-local static: constructed on the first call and reused after.
  static PythonArgParser parser({
      "nested_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
  });

  // Must equal the number of parameters in the signature above:
  // data, dtype, device, pin_memory, requires_grad.
  constexpr int ctor_num_args = 5;
  ParsedArgs<ctor_num_args> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  // Issued only after parser.parse() returns, so a call that fails to parse
  // never reaches the tracer warning.
  jit::tracer::warn(
      "torch.nested.nested_tensor", jit::tracer::WARN_CONSTRUCTOR);
  return THPVariable_Wrap(torch::utils::nested_tensor_ctor(
      torch::tensors::get_default_dispatch_key(),
      torch::tensors::get_default_scalar_type(),
      r));
  END_HANDLE_TH_ERRORS
}

// Method table for the hand-written nested-tensor bindings.
//
// The table ends with an all-null sentinel entry. CPython requires this for
// PyMethodDef arrays (e.g. PyModule_AddFunctions, PyTypeObject::tp_methods).
// The table is exposed only as a bare PyMethodDef* with no length, so the
// sentinel is the only way a consumer can find its end. Consumers that only
// index the named entry ([0]) are unaffected by it.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static PyMethodDef nested_functions_manual[] = {
    {"nested_tensor",
     castPyCFunctionWithKeywords(THPVariable_nested_tensor),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {nullptr, nullptr, 0, nullptr}, // sentinel: marks the end of the table
};

// Returns the manual nested-function method table.
//
// The result is a non-owning pointer to static storage and is never null.
// Entry [0] is "nested_tensor". The table ends with a sentinel entry whose
// ml_name is nullptr.
PyMethodDef* get_nested_functions_manual() {
  return nested_functions_manual;
}

} // namespace autograd
} // namespace torch