forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_class_detail.h
155 lines (127 loc) · 4.77 KB
/
custom_class_detail.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#pragma once
#include <ATen/core/boxing/kernel_functor.h>
#include <ATen/core/function.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeTraits.h>

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
namespace torch {
namespace detail {
// Argument type utilities

// Compile-time holder for a non-empty list of types (first parameter R,
// followed by zero or more further types). The nested `type` member names
// the list type itself.
template <class R, class... Rest>
struct types {
  using type = types<R, Rest...>;
};
// WrapMethod<Method> turns a pointer to a member function of CurrClass into
// a callable whose first parameter is the target object, held by
// c10::intrusive_ptr<CurrClass>, followed by the method's own parameters.
// The primary template is only declared; only the member-function-pointer
// specializations below (non-const and const) are defined.
template <typename Method>
struct WrapMethod;

// Specialization for non-const member functions.
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...)> {
  // Member function pointers are trivially copyable; a plain copy suffices.
  WrapMethod(R (CurrClass::*method)(Args...)) : m(method) {}

  // Invokes the stored method on *cur. Each argument is forwarded with its
  // declared value category: by-value parameters are moved instead of being
  // copied a second time, and rvalue-reference parameters (Args = T&&) can
  // bind, which is impossible when `args` is passed on as an lvalue.
  R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
    return c10::guts::invoke(m, *cur, std::forward<Args>(args)...);
  }

  R (CurrClass::*m)(Args...);
};

// Specialization for const-qualified member functions; identical to the
// non-const one except for the qualifier on the stored pointer type.
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...) const> {
  WrapMethod(R (CurrClass::*method)(Args...) const) : m(method) {}

  // See the non-const specialization for why arguments are forwarded.
  R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
    return c10::guts::invoke(m, *cur, std::forward<Args>(args)...);
  }

  R (CurrClass::*m)(Args...) const;
};
// Adapter for different callable types.
//
// wrap_func<CurClass>(f) wraps f in WrapMethod<Func> when Func is a pointer
// to member function. That makes the result callable with the target object
// as its first argument. Any other callable is returned unchanged.
// CurClass does not appear in the parameter list, so it cannot be deduced
// and must be given explicitly; neither overload's body refers to it.
template <
    typename CurClass,
    typename Func,
    typename std::enable_if<
        std::is_member_function_pointer<typename std::decay<Func>::type>::value,
        bool>::type = false>
WrapMethod<Func> wrap_func(Func f) {
  WrapMethod<Func> wrapped(std::move(f));
  return wrapped;
}

// Pass-through overload for callables that are not member function pointers.
template <
    typename CurClass,
    typename Func,
    typename std::enable_if<
        !std::is_member_function_pointer<typename std::decay<Func>::type>::value,
        bool>::type = false>
Func wrap_func(Func f) {
  return f;
}
// Calls `functor` with one argument per index in `ivalue_arg_indices`.
// For index k, the value returned by
// torch::jit::peek(stack, k, sizeof...(ivalue_arg_indices)) is moved into
// c10::detail::ivalue_to_arg, converting it to the cv/ref-stripped type of
// the functor's k-th parameter. Returns whatever the functor returns. This
// function does not pop anything from `stack`.
template <
    class Functor,
    bool AllowDeprecatedTypes,
    size_t... ivalue_arg_indices>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(
    Functor& functor,
    jit::Stack& stack,
    std::index_sequence<ivalue_arg_indices...>) {
  // With an empty index pack the expansion below never mentions `stack`;
  // this cast keeps the compiler from warning about an unused parameter.
  static_cast<void>(stack);

  using FunctorTraits = c10::guts::infer_function_traits_t<Functor>;
  using ParamTypes = typename FunctorTraits::parameter_types;
  constexpr size_t arg_count = sizeof...(ivalue_arg_indices);

  return functor(
      c10::detail::ivalue_to_arg<
          std::remove_cv_t<std::remove_reference_t<
              c10::guts::typelist::element_t<ivalue_arg_indices, ParamTypes>>>,
          AllowDeprecatedTypes>(
          std::move(torch::jit::peek(stack, ivalue_arg_indices, arg_count)))...);
}
// Convenience overload: builds an index sequence covering every declared
// parameter of Functor and delegates to the index_sequence overload.
template <class Functor, bool AllowDeprecatedTypes>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) {
  using ArgIndices = std::make_index_sequence<
      c10::guts::infer_function_traits_t<Functor>::number_of_parameters>;
  return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
      functor, stack, ArgIndices{});
}
// Boxed-call adapter. operator() reads func's arguments from `stack` via
// call_torchbind_method_from_stack (with AllowDeprecatedTypes = false),
// calls func, drops one stack entry per declared parameter of Func, and then
// pushes the result converted by c10::ivalue::from.
template <class RetType, class Func>
struct BoxedProxy {
  void operator()(jit::Stack& stack, Func& func) {
    constexpr size_t kArgCount =
        c10::guts::infer_function_traits_t<Func>::number_of_parameters;
    auto result = call_torchbind_method_from_stack<Func, false>(func, stack);
    torch::jit::drop(stack, kArgCount);
    stack.emplace_back(c10::ivalue::from(std::move(result)));
  }
};

// Specialization for RetType = void: there is no result to convert, so a
// default-constructed c10::IValue is pushed in its place.
template <class Func>
struct BoxedProxy<void, Func> {
  void operator()(jit::Stack& stack, Func& func) {
    constexpr size_t kArgCount =
        c10::guts::infer_function_traits_t<Func>::number_of_parameters;
    call_torchbind_method_from_stack<Func, false>(func, stack);
    torch::jit::drop(stack, kArgCount);
    stack.emplace_back(c10::IValue());
  }
};
// Returns true if character `n` may appear at position `i` of an identifier.
// An ASCII letter or '_' is allowed at any position; an ASCII digit is
// allowed anywhere except position 0.
//
// The checks are written out rather than calling isalpha/isdigit. Those
// functions take an int that must be representable as unsigned char (or be
// EOF), so passing a negative `char` (any byte >= 0x80 where char is signed)
// is undefined behavior. isalpha's result also depends on the current C
// locale. These comparisons give the same answers as the "C" locale for
// every non-negative char and reject all non-ASCII bytes.
inline bool validIdent(size_t i, char n) {
  const bool is_letter = (n >= 'a' && n <= 'z') || (n >= 'A' && n <= 'Z');
  const bool is_digit = n >= '0' && n <= '9';
  return is_letter || n == '_' || (i > 0 && is_digit);
}
// Fails a TORCH_CHECK unless `str` is non-empty and every character passes
// validIdent for its position. `type` is the leading text of the error
// message and describes what is being checked.
inline void checkValidIdent(const std::string& str, const char *type) {
  // Without this check an empty string would pass the per-character loop
  // vacuously, even though it is not a valid identifier.
  TORCH_CHECK(!str.empty(),
      type,
      " must be a valid Python/C++ identifier."
      " It must not be empty.");
  for (size_t i = 0; i < str.size(); ++i) {
    TORCH_CHECK(validIdent(i, str[i]),
        type,
        " must be a valid Python/C++ identifier."
        " Character '", str[i], "' at index ",
        i, " is illegal.");
  }
}
} // namespace detail
// Registers a custom class type. NOTE(review): defined outside this header;
// presumably a registered type later becomes visible to getCustomClass below
// — confirm against the implementation.
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
// Registers a method for a custom class. NOTE(review): the implementation is
// not visible here; confirm what takes ownership of `method` and how the
// owning class is determined.
TORCH_API void registerCustomClassMethod(std::shared_ptr<jit::Function> method);
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);
// Re-export the two registration functions so they can also be named as
// torch::jit::registerCustomClass and torch::jit::registerCustomClassMethod.
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
} // namespace jit
} // namespace torch