forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_autograd_functions.py
219 lines (184 loc) · 7.32 KB
/
gen_autograd_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two header/source pairs (four files in total):
# Functions.h/cpp: subclasses of autograd::Function
# python_functions.h/cpp: Python bindings for the above classes
#
import re
from .utils import nested_dict, CodeTemplate, write
from .gen_autograd import VIEW_FUNCTIONS
from .utils import IDENT_REGEX
# C++ declaration of one generated autograd function struct.
# Placeholders:
#   ${op}                     - struct name; also the string returned by name()
#   ${superclass}             - base class; its constructors are inherited via `using`
#   ${release_variables}      - statements forming the body of release_variables()
#   ${will_release_variables} - optional member block (see WILL_RELEASE_VARIABLES), or ''
#   ${saved_variables}        - member declarations for saved inputs/outputs
#   ${saved_list_sizes}       - `size_t <name>_size_;` members for TensorList arguments
# process_function() fills every placeholder except ${op}; ${op} presumably comes
# from the func dict through nested_dict -- TODO confirm.
# The backslash after the opening quotes keeps the template from starting with a newline.
FUNCTION_DECLARATION = CodeTemplate("""\
struct ${op} : public ${superclass} {
using ${superclass}::${superclass};
variable_list apply(variable_list&& grads) override;
std::string name() const override { return "${op}"; }
void release_variables() override {
${release_variables}
}
${will_release_variables}
${saved_variables}
${saved_list_sizes}
};
""")
# Optional C++ member block: a `retain_variables` flag that starts out true and
# is set to false by the will_release_variables() override.
# process_function() only emits it when uses_retain_variables(func) is true,
# i.e. when a derivative formula refers to `retain_variables`.
# The template has no placeholders, so it is substituted with no arguments.
WILL_RELEASE_VARIABLES = CodeTemplate("""\
bool retain_variables = true;
void will_release_variables() override {
retain_variables = false;
}
""")
# C++ definition of ${op}::apply().
# Placeholders:
#   ${op}                   - struct name the method belongs to
#   ${compute_index_ranges} - `auto <arg>_ix = gen.range(<size>);` lines, one per
#                             argument with a gradient (built by process_function)
#   ${body}                 - unpacking statements followed by the derivative blocks
#                             that write into grad_inputs
# grad_inputs is sized from gen.size(), so the index-range lines must come first.
FUNCTION_DEFINITION = CodeTemplate("""\
variable_list ${op}::apply(variable_list&& grads) {
IndexRangeGenerator gen;
${compute_index_ranges}
variable_list grad_inputs(gen.size());
${body}
return grad_inputs;
}
""")
# C++ snippet emitted once per generated function for the Python bindings:
# declares a static PyTypeObject named ${op}Class and passes it, together with
# the name string "${op}", to addClass<${op}>.
# gen_autograd_functions() collects these under 'py_function_initializers'.
PY_FUNCTION_DEFINITION = CodeTemplate("""\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}");
""")
# C++ declaration of the `grad_input_mask` array used by multi-output derivative
# formulas.
# Placeholders:
#   ${n}     - number of entries (len(var_names) in process_function)
#   ${masks} - one `should_compute_output({ <name>_ix }),` entry per variable
# The backslash after `};` is a line continuation inside the string literal, so
# the rendered text does not end with a newline.
GRAD_INPUT_MASK = CodeTemplate("""\
auto grad_input_mask = std::array<bool, ${n}>{
${masks}
};\
""")
# C++ block for a derivative formula that produces the gradient of exactly one
# variable (len(var_names) == 1 in process_function).
# Placeholders:
#   ${name}       - variable name; `<name>_ix` is its index range in grad_inputs
#   ${derivative} - the C++ formula expression, pasted in verbatim
# The formula is only evaluated when should_compute_output reports that the
# gradient for this range is needed.
DERIVATIVE_SINGLE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
auto grad_result = ${derivative};
copy_range(grad_inputs, ${name}_ix, grad_result);
}
""")
# C++ block that copies one element of a multi-output derivative result into
# grad_inputs. Rendered once per variable and spliced into DERIVATIVE_MULTI as
# ${copy_ranges}.
# Placeholders:
#   ${name} - variable name; `<name>_ix` is its index range in grad_inputs
#   ${i}    - position of this variable in var_names, used with std::get on
#             `grad_result`
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
}
""")
# C++ block for a derivative formula that produces gradients for several
# variables at once (len(var_names) > 1 in process_function).
# Placeholders:
#   ${idx_ranges}      - comma-separated `<name>_ix` list of all covered variables
#   ${grad_input_mask} - rendered GRAD_INPUT_MASK, or '' when the formula does not
#                        mention `grad_input_mask`
#   ${derivative}      - the C++ formula expression, pasted in verbatim
#   ${copy_ranges}     - one rendered DERIVATIVE_MULTI_COPY_RANGE per variable
DERIVATIVE_MULTI = CodeTemplate("""\
if (should_compute_output({ ${idx_ranges} })) {
${grad_input_mask}
auto grad_result = ${derivative};
${copy_ranges}
}
""")
# These functions have backward passes which cannot be traced, and so must
# have their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backward; see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
# NOTE: this binds a second name to the same object as VIEW_FUNCTIONS (no copy
# is made). process_function() uses it to choose 'Function' instead of
# 'TraceableFunction' as the generated superclass.
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_autograd_functions(out, autograd_functions, template_path):
    """Write Functions.h/.cpp and python_functions.h/.cpp.

    Each entry of ``autograd_functions`` is passed through
    ``process_function`` and the resulting environment is rendered three
    times: as a C++ struct declaration, as the matching ``apply``
    definition, and as a Python-binding initializer. The collected snippets
    are then substituted into the four file templates found under
    ``template_path`` and handed to ``write`` together with ``out``.
    """
    # Load every file-level template before doing any per-function work.
    # A list of pairs keeps the output order explicit.
    outputs = [
        ('Functions.h',
         CodeTemplate.from_file(template_path + '/Functions.h')),
        ('Functions.cpp',
         CodeTemplate.from_file(template_path + '/Functions.cpp')),
        ('python_functions.h',
         CodeTemplate.from_file(template_path + '/python_functions.h')),
        ('python_functions.cpp',
         CodeTemplate.from_file(template_path + '/python_functions.cpp')),
    ]

    declarations = []
    definitions = []
    initializers = []
    for autograd_function in autograd_functions:
        function_env = process_function(autograd_function)
        declarations.append(FUNCTION_DECLARATION.substitute(function_env))
        definitions.append(FUNCTION_DEFINITION.substitute(function_env))
        initializers.append(PY_FUNCTION_DEFINITION.substitute(function_env))

    # All four files are rendered from the same top-level environment.
    top_env = {
        'autograd_function_definitions': definitions,
        'autograd_function_declarations': declarations,
        'py_function_initializers': initializers,
    }

    for filename, file_template in outputs:
        write(out, filename, file_template, top_env)
def process_function(func):
    """Build the template environment for one autograd function.

    Reads ``func['args_with_gradients']``, ``func['saved_inputs']``,
    ``func['saved_outputs']``, ``func['derivatives']`` and ``func['name']``
    and turns them into the C++ snippets consumed by FUNCTION_DECLARATION
    and FUNCTION_DEFINITION.

    Returns ``nested_dict(env, func)`` where ``env`` carries the keys
    ``compute_index_ranges``, ``saved_variables``, ``release_variables``,
    ``saved_list_sizes``, ``will_release_variables``, ``body`` and
    ``superclass``.
    """
    member_decls = []        # struct members for saved inputs/outputs
    release_stmts = []       # body of release_variables()
    list_size_members = []   # size_t members for TensorList arguments
    unpack_stmts = []        # locals recreated at the top of apply()

    # One index range per argument that receives a gradient.
    index_ranges = []
    for grad_arg in func['args_with_gradients']:
        arg_is_list = grad_arg['type'] == 'TensorList'
        arg_name = grad_arg['name']
        if arg_is_list:
            # A list's range size is read from a generated size_t member.
            list_size_members.append('size_t {}_size_;'.format(arg_name))
        range_size = '{}_size_'.format(arg_name) if arg_is_list else '1'
        index_ranges.append(
            'auto {}_ix = gen.range({});'.format(arg_name, range_size))

    def record_saved(arg, is_output):
        # Emit the member declaration for one saved value, plus release and
        # unpack code when its type needs them.
        name = arg['name']
        kind = arg['type']

        if kind == 'Tensor' or (is_output and kind == 'Scalar'):
            member_decls.append('SavedVariable {}_;'.format(name))
            release_stmts.extend([
                '{}_.reset_data();'.format(name),
                '{}_.reset_grad_function();'.format(name),
            ])
            # Saved outputs are unpacked with shared_from_this() as argument.
            if is_output:
                unpack_arg = 'shared_from_this()'
            else:
                unpack_arg = ''
            unpack_stmts.append(
                'auto {} = {}_.unpack({});'.format(name, name, unpack_arg))
            return

        if kind == 'TensorList':
            member_decls.append(
                'std::vector<SavedVariable> {}_;'.format(name))
            release_stmts.append('{}_.clear();'.format(name))
            unpack_stmts.append(
                'auto {} = unpack_list({}_);'.format(name, name))
            return

        # Remaining types are stored by value under their own name.
        if kind == 'IntList':
            declaration = 'std::vector<int64_t> {};'.format(name)
        elif kind == 'int64_t':
            declaration = '{} {} = 0;'.format(kind, name)
        else:
            declaration = '{} {};'.format(kind, name)
        member_decls.append(declaration)

    for saved_input in func['saved_inputs']:
        record_saved(saved_input, False)
    for saved_output in func['saved_outputs']:
        record_saved(saved_output, True)

    # The retain_variables member is emitted only if a formula refers to it.
    will_release = ''
    if uses_retain_variables(func):
        will_release = WILL_RELEASE_VARIABLES.substitute()

    def render_derivative(derivative):
        # Render the C++ block for one derivative entry.
        formula = derivative['formula']
        names = derivative['var_names']

        if len(names) == 1:
            return DERIVATIVE_SINGLE.substitute(
                name=names[0], derivative=formula)

        mask_decl = ''
        if 'grad_input_mask' in formula:
            mask_entries = [
                'should_compute_output({{ {}_ix }}),'.format(var)
                for var in names
            ]
            mask_decl = GRAD_INPUT_MASK.substitute(
                masks=mask_entries, n=len(names))

        joined_ranges = ', '.join(["{}_ix".format(var) for var in names])
        range_copies = [
            DERIVATIVE_MULTI_COPY_RANGE.substitute(name=var, i=position)
            for position, var in enumerate(names)
        ]
        return DERIVATIVE_MULTI.substitute(
            idx_ranges=joined_ranges,
            copy_ranges=range_copies,
            derivative=formula,
            grad_input_mask=mask_decl)

    # apply() body: optional `grad` alias, then unpacking, then derivatives.
    body = ['auto& grad = grads[0];'] if uses_single_grad(func) else []
    body.extend(unpack_stmts)
    for derivative in func['derivatives']:
        body.append(render_derivative(derivative))

    is_untraceable = func['name'] in UNTRACEABLE_FUNCTIONS
    superclass = 'Function' if is_untraceable else 'TraceableFunction'

    env = {
        'compute_index_ranges': index_ranges,
        'saved_variables': member_decls,
        'release_variables': release_stmts,
        'saved_list_sizes': list_size_members,
        'will_release_variables': will_release,
        'body': body,
        'superclass': superclass,
    }
    return nested_dict(env, func)
def uses_ident(func, ident):
    """Return True if any derivative formula of ``func`` mentions ``ident``.

    A formula counts as a match when ``re.search`` finds the pattern
    ``IDENT_REGEX.format(ident)`` in it. ``func`` may be ``None``, in which
    case the result is False; so is the result when ``func['derivatives']``
    is empty.
    """
    if func is None:
        return False
    formulas = (derivative['formula'] for derivative in func['derivatives'])
    return any(
        re.search(IDENT_REGEX.format(ident), formula) for formula in formulas
    )
def uses_retain_variables(func):
    """Return whether a derivative formula of ``func`` refers to ``retain_variables``.

    Delegates to ``uses_ident``; ``process_function`` uses the answer to decide
    whether the WILL_RELEASE_VARIABLES member block is generated.
    """
    identifier = 'retain_variables'
    return uses_ident(func, identifier)
def uses_single_grad(func):
    """Return whether a derivative formula of ``func`` refers to ``grad``.

    Delegates to ``uses_ident``; ``process_function`` uses the answer to decide
    whether ``auto& grad = grads[0];`` is emitted at the top of ``apply``.
    """
    identifier = 'grad'
    return uses_ident(func, identifier)