forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_class_detail.h
239 lines (203 loc) · 7.6 KB
/
custom_class_detail.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
#pragma once
#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
#include <ATen/core/function.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeTraits.h>
#include <c10/util/irange.h>

#include <cctype>
#include <utility>
namespace torch {
namespace detail {
/**
 * Custom-class load recording. Everything in the first branch below is
 * compiled in only when ENABLE_RECORD_KERNEL_FUNCTION_DTYPE is defined.
 *
 * In the Facebook internal build (using BUCK), this macro is enabled by
 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
 * binary.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
// Only declared here; the definition lives out of line. Takes the name to
// record by value.
TORCH_API void record_custom_class(std::string name);
/**
 * Record an instance of a custom class being loaded.
 *
 * Only the portion of the qualified name after the final '.' is recorded,
 * as this seemingly aligns with how users name their custom classes.
 * Example: for __torch__.torch.classes.xnnpack.Conv2dOpContext the recorded
 * name is Conv2dOpContext. If NAME contains no '.', find_last_of returns
 * std::string::npos, npos + 1 wraps around to 0, and the whole name is
 * recorded.
 *
 * Things to keep in mind, all of which follow from the expansion itself:
 *  - It expands to two statements and declares a local variable called
 *    `name` in the enclosing scope, so it cannot be used twice in the same
 *    scope, nor in a scope that already declares `name`, nor as the sole
 *    body of an unbraced if/for/while.
 *  - It refers to `detail::record_custom_class`, so it must be expanded
 *    where that qualified name resolves to the declaration above.
 *  - The expansion already ends with ';'.
 */
#define RECORD_CUSTOM_CLASS(NAME) \
auto name = std::string(NAME); \
detail::record_custom_class(name.substr(name.find_last_of(".") + 1));
#else
// Recording is disabled: the macro expands to nothing and NAME is not
// evaluated.
#define RECORD_CUSTOM_CLASS(NAME)
#endif
} // namespace detail
/// Describes one named method argument, optionally carrying a default
/// value, for use when registering methods on custom classes, e.g.:
///
///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
///         .def("myMethod", &Foo::myMethod, {torch::arg("name") = name});
struct arg {
  /// Value that represents a default of None. Intended usage:
  ///     torch::arg("name") = torch::arg::none()
  /// which is equivalent to:
  ///     torch::arg("name") = IValue()
  static c10::IValue none() {
    return c10::IValue();
  }

  /// Builds an argument called `name` with no default value attached;
  /// `value_` starts out empty.
  explicit arg(std::string name) : name_(std::move(name)) {}

  /// Attaches `rhs` as the default value and returns this object, which is
  /// what enables the pybind-like spelling torch::arg("name") = value.
  arg& operator=(const c10::IValue& rhs) {
    value_ = rhs;
    return *this;
  }

  /// The argument's name. It is copied to the schema, because argument
  /// names cannot be extracted from the C++ declaration.
  std::string name_;

  /// The default value, if one has been assigned. A default-constructed
  /// IValue is None, which would be indistinguishable from a user-provided
  /// default of None; wrapping the value in an optional keeps "no default"
  /// (empty optional) apart from "default is None" (optional holding None).
  c10::optional<c10::IValue> value_;
};
namespace detail {
// Argument type utilities

/// Compile-time tag that bundles a leading type `R` with a (possibly empty)
/// pack of further types. The nested `type` names the instantiation itself.
template <class R, class... Rest>
struct types {
  using type = types<R, Rest...>;
};
/// Adapts a pointer to member function into a callable whose first argument
/// is the receiver object, held in a c10::intrusive_ptr, followed by the
/// member function's own arguments. Only the two specializations below
/// (non-const and const member functions) are defined.
template <typename Method>
struct WrapMethod;

/// Specialization for non-const member functions `R CurrClass::f(Args...)`.
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...)> {
  /// Stores the member function pointer. Intentionally not `explicit`, as in
  /// the original interface.
  WrapMethod(R (CurrClass::*m)(Args...)) : m(m) {}

  /// Invokes the stored member function on `*cur`. `cur` is dereferenced
  /// without a null check. The by-value `args` are forwarded rather than
  /// passed as lvalues: by-value parameters of the target are moved into
  /// place instead of copied a second time, lvalue-reference parameters are
  /// still passed as lvalues, and move-only / rvalue-reference parameters
  /// become usable.
  R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
    return c10::guts::invoke(m, *cur, std::forward<Args>(args)...);
  }

  /// The wrapped member function.
  R (CurrClass::*m)(Args...);
};

/// Specialization for const member functions `R CurrClass::f(Args...) const`.
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...) const> {
  /// Stores the member function pointer. Intentionally not `explicit`, as in
  /// the original interface.
  WrapMethod(R (CurrClass::*m)(Args...) const) : m(m) {}

  /// Invokes the stored const member function on `*cur`, forwarding `args`
  /// exactly as the non-const specialization does. `cur` is dereferenced
  /// without a null check.
  R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
    return c10::guts::invoke(m, *cur, std::forward<Args>(args)...);
  }

  /// The wrapped const member function.
  R (CurrClass::*m)(Args...) const;
};
// Adapter for different callable types.
//
// This overload is selected when `Func`, after decay, is a pointer to member
// function: the pointer is wrapped in a WrapMethod<Func>, which is returned.
//
// `CurClass` does not appear in the parameter list or the body, so it cannot
// be deduced and callers have to spell it out explicitly. The trailing bool
// template parameter is never meant to be supplied; it exists only so that
// std::enable_if_t can remove this overload for other kinds of callables.
template <
typename CurClass,
typename Func,
std::enable_if_t<
std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
WrapMethod<Func> wrap_func(Func f) {
return WrapMethod<Func>(std::move(f));
}
// Overload selected when `Func`, after decay, is NOT a pointer to member
// function (the condition is the negation of
// std::is_member_function_pointer): the callable is returned unchanged.
//
// `CurClass` does not appear in the parameter list or the body, so it cannot
// be deduced and callers have to spell it out explicitly. The trailing bool
// template parameter exists only for the std::enable_if_t constraint.
template <
typename CurClass,
typename Func,
std::enable_if_t<
!std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
Func wrap_func(Func f) {
return f;
}
// Calls `functor` with arguments read from the JIT `stack`.
//
// Template parameters:
//   Functor               callable whose parameter and return types are
//                         obtained via c10::guts::infer_function_traits_t.
//   AllowDeprecatedTypes  passed through unchanged to
//                         c10::impl::ivalue_to_arg.
//   ivalue_arg_indices    one index per functor parameter; deduced from the
//                         std::index_sequence argument, whose value is
//                         otherwise unused.
//
// For every index i, the entry torch::jit::peek(stack, i, N) -- N being the
// number of indices -- is converted by ivalue_to_arg to the functor's i-th
// parameter type, after that type has gone through
// c10::impl::decay_if_not_tensor. Only `peek` is applied to the stack in
// this function; no `drop` or push happens here.
// NOTE(review): presumably peek(stack, i, N) addresses the i-th of the last
// N stack entries -- confirm against torch::jit::peek.
//
// Returns whatever `functor` returns.
template <
class Functor,
bool AllowDeprecatedTypes,
size_t... ivalue_arg_indices>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(
Functor& functor,
jit::Stack& stack,
std::index_sequence<ivalue_arg_indices...>) {
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
// be unused and we have to silence the compiler warning.
// Number of arguments to read; also the window size handed to peek().
constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices);
// Typelist of the functor's declared parameter types, indexed below.
using IValueArgTypes =
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
// TODO We shouldn't use c10::impl stuff directly here. We should use the
// KernelFunction API instead.
// The pack expansion below produces one converted argument per index.
return (functor)(c10::impl::ivalue_to_arg<
typename c10::impl::decay_if_not_tensor<
c10::guts::typelist::
element_t<ivalue_arg_indices, IValueArgTypes>>::type,
AllowDeprecatedTypes>::
call(torch::jit::peek(
stack, ivalue_arg_indices, num_ivalue_args))...);
}
/// Convenience overload: determines how many parameters `Functor` declares
/// and delegates to the std::index_sequence overload with the indices
/// 0 .. N-1. Returns whatever that overload returns.
template <class Functor, bool AllowDeprecatedTypes>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) {
  using FunctorTraits = c10::guts::infer_function_traits_t<Functor>;
  using ArgIndices =
      std::make_index_sequence<FunctorTraits::number_of_parameters>;
  return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
      functor, stack, ArgIndices());
}
template <class RetType, class Func>
struct BoxedProxy;
template <class RetType, class Func>
struct BoxedProxy {
void operator()(jit::Stack& stack, Func& func) {
auto retval = call_torchbind_method_from_stack<Func, false>(func, stack);
constexpr size_t num_ivalue_args =
c10::guts::infer_function_traits_t<Func>::number_of_parameters;
torch::jit::drop(stack, num_ivalue_args);
stack.emplace_back(c10::ivalue::from(std::move(retval)));
}
};
template <class Func>
struct BoxedProxy<void, Func> {
void operator()(jit::Stack& stack, Func& func) {
call_torchbind_method_from_stack<Func, false>(func, stack);
constexpr size_t num_ivalue_args =
c10::guts::infer_function_traits_t<Func>::number_of_parameters;
torch::jit::drop(stack, num_ivalue_args);
stack.emplace_back(c10::IValue());
}
};
/// Returns true when character `n` may appear at position `i` of an
/// identifier: letters and '_' are accepted at any position, digits only
/// when `i > 0`. Letter/digit classification is done by std::isalpha and
/// std::isdigit, i.e. it follows the current C locale.
inline bool validIdent(size_t i, char n) {
  // The <cctype> classifiers have undefined behavior for arguments that are
  // neither EOF nor representable as unsigned char. A plain `char` holding a
  // non-ASCII byte is negative where `char` is signed, so convert first.
  const auto c = static_cast<unsigned char>(n);
  return std::isalpha(c) || n == '_' || (i > 0 && std::isdigit(c));
}
/// Checks every character of `str` with validIdent() and reports the first
/// offending one through TORCH_CHECK. `type` is the leading text of the
/// failure message and says what kind of name is being checked. An empty
/// `str` passes, because there is no character to reject.
inline void checkValidIdent(const std::string& str, const char* type) {
  for (const auto pos : c10::irange(str.size())) {
    const char ch = str[pos];
    TORCH_CHECK(
        validIdent(pos, ch),
        type,
        " must be a valid Python/C++ identifier."
        " Character '",
        ch,
        "' at index ",
        pos,
        " is illegal.");
  }
}
// Non-template base holding the state shared by custom-class registration
// helpers. Every member is protected, so only derived classes can construct
// it or reach its members; the constructor and withNewArguments are merely
// declared here and defined out of line.
// NOTE(review): presumably this is the base of torch::class_<T> -- the
// derived class is not visible in this header; confirm.
class TORCH_API class_base {
protected:
// namespaceName / className: the two components of the class's name.
// NOTE(review): presumably combined into qualClassName -- confirm in the
// definition.
// doc_string: documentation text, taken by value.
// intrusivePtrClassTypeid / taggedCapsuleClass: std::type_info objects for
// two C++ types associated with the class; their exact roles are not visible
// from this declaration.
explicit class_base(
const std::string& namespaceName,
const std::string& className,
std::string doc_string,
const std::type_info& intrusivePtrClassTypeid,
const std::type_info& taggedCapsuleClass);
// Produces a FunctionSchema from `schema` and the given `default_args`.
// NOTE(review): presumably a copy of `schema` whose arguments take their
// names and default values from `default_args` -- confirm in the definition.
static c10::FunctionSchema withNewArguments(
const c10::FunctionSchema& schema,
std::initializer_list<arg> default_args);
// Name of the class as a string.
// NOTE(review): presumably the fully qualified name -- confirm.
std::string qualClassName;
// Pointer to the ClassType associated with this class.
at::ClassTypePtr classTypePtr;
};
} // namespace detail
// Registers `class_type` as a custom class.
// NOTE(review): presumably this is what makes the type discoverable through
// getCustomClass() -- the definition is out of line; confirm.
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
// Registers a method of a custom class. `method` is passed as a
// std::unique_ptr by value, so ownership moves to the callee.
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);
// Returns a list of function schemas related to custom classes.
// NOTE(review): "BCCheck" presumably stands for backward-compatibility
// check -- confirm.
// This API is for testing purposes ONLY. It should not be used in
// any load-bearing code.
TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck();
// Re-export the two registration functions so that they can also be named
// as torch::jit::registerCustomClass and
// torch::jit::registerCustomClassMethod.
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
} // namespace jit
} // namespace torch