forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargument_spec.cpp
291 lines (280 loc) · 10.3 KB
/
argument_spec.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <iostream>
namespace torch {
namespace jit {
// Appends to instructions_ the instructions that consume one value of type
// `typ` when replayed by create() and specializeTypes().
//
// Tensor, Optional[Tensor] and other Optional values become SPECIALIZE_*
// instructions. Tuples and Classes become ENTER_TUPLE / ENTER_OBJECT, then
// the instructions for their elements, then LEAVE. Everything else is a
// SKIP. An aggregate that contains nothing to specialize is collapsed into
// a single SKIP.
//
// typ: the declared type of the value.
// depth: nesting level of the value; graph inputs are scanned at depth 0.
// written_slots: keys (qualified class name immediately followed by the
//   attribute name) of class attributes the graph may assign. Those
//   attributes are never specialized.
void ArgumentSpecCreator::scan(
    const TypePtr& typ,
    size_t depth,
    const WrittenSlots& written_slots) {
  auto finishAggregate = [&](size_t pos) {
    // it is possible after all the work we did to scan this aggregate,
    // we found no tensors or optionals to specialize. In this case, just
    // generate a skip for the whole aggregate.
    bool any_spec = std::any_of(
        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
          return i == SPECIALIZE_TENSOR || i == SPECIALIZE_OPTIONAL ||
              i == SPECIALIZE_OPTIONAL_TENSOR;
        });
    if (!any_spec) {
      instructions_[pos] = SKIP;
      instructions_.resize(pos + 1);
    } else {
      instructions_.emplace_back(LEAVE);
    }
  };
  // The simple VM in create() keeps a fixed-size stack of
  // ARG_SPEC_DEPTH_LIMIT cursors, so values nested this deeply are skipped.
  // The early return matters. Falling through would emit a second
  // instruction for the same value, making the VM consume more values than
  // this level holds. A specializable leaf here would also keep the
  // enclosing ENTER_* alive, and that push would run past the end of the
  // VM's stack. With the return, an aggregate one level above the limit
  // sees only SKIPs, and finishAggregate() collapses it into a single SKIP.
  if (depth >= ARG_SPEC_DEPTH_LIMIT) {
    instructions_.emplace_back(SKIP);
    return;
  }
  if (typ->isSubtypeOf(*TensorType::get())) {
    num_tensors_++;
    instructions_.emplace_back(SPECIALIZE_TENSOR);
  } else if (typ->isSubtypeOf(*OptionalType::ofTensor())) {
    num_tensors_++;
    num_optionals_++;
    instructions_.emplace_back(SPECIALIZE_OPTIONAL_TENSOR);
  } else if (typ->kind() == TypeKind::OptionalType) {
    // note that Optional[Tuple] or Optional[Class] will just register
    // as optional (previously they didn't at all, so it's not a regression).
    num_optionals_++;
    instructions_.emplace_back(SPECIALIZE_OPTIONAL);
  } else if (auto tup = typ->cast<TupleType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_TUPLE);
    for (const auto& elem : tup->containedTypes()) {
      scan(elem, depth + 1, written_slots);
    }
    finishAggregate(pos);
  } else if (auto cls = typ->cast<ClassType>()) {
    size_t pos = instructions_.size();
    instructions_.emplace_back(ENTER_OBJECT);
    for (size_t i = 0; i < cls->numAttributes(); ++i) {
      auto key =
          cls->name()->qualifiedName() + cls->getAttributes().at(i).getName();
      // It is only safe to specialize an attribute that nobody writes to,
      // so written attributes are skipped.
      if (!written_slots.count(key)) {
        scan(cls->containedTypes().at(i), depth + 1, written_slots);
      } else {
        instructions_.emplace_back(SKIP);
      }
    }
    finishAggregate(pos);
  } else {
    instructions_.emplace_back(SKIP);
  }
}
// Records in `written_slots` every class attribute that a prim::SetAttr
// anywhere in `block` may write. The search includes nested blocks and
// graphs stored in attr::Subgraph. Keys are the class's qualified name
// concatenated directly with the attribute name, which is the same key
// ArgumentSpecCreator::scan looks up.
//
// This is a coarse-grained guarantee that the slots of a class will not be
// modified by the function. It works fine for things that used to be
// read-only modules, but it is overly conservative when some classes are
// written to. Doing alias analysis and looking for actual writes to the
// class would be more accurate.
static void scanWrittenSlots(
    Block* block,
    ArgumentSpecCreator::WrittenSlots& written_slots) {
  for (Node* node : block->nodes()) {
    const bool writes_attribute = node->kind() == prim::SetAttr;
    if (writes_attribute) {
      auto owner = node->inputs().at(0)->type()->cast<ClassType>();
      if (owner) {
        written_slots.insert(
            owner->name()->qualifiedName() + node->s(attr::name));
      }
    }
    // Control-flow bodies can also contain SetAttr nodes.
    for (Block* nested : node->blocks()) {
      scanWrittenSlots(nested, written_slots);
    }
    // So can graphs attached as subgraphs.
    if (node->hasAttribute(attr::Subgraph)) {
      scanWrittenSlots(node->g(attr::Subgraph)->block(), written_slots);
    }
  }
}
// Builds the instruction program for `graph`. It first collects the class
// attributes the graph may write (those are never specialized), then scans
// each graph input's type, in input order, starting at nesting depth 0.
ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
    : num_inputs_(graph.inputs().size()) {
  WrittenSlots slots_written;
  scanWrittenSlots(graph.block(), slots_written);
  const auto graph_inputs = graph.inputs();
  for (const auto idx : c10::irange(graph_inputs.size())) {
    scan(graph_inputs[idx]->type(), /*depth=*/0, slots_written);
  }
}
void ArgumentSpecCreator::dump() const {
for (Inst inst : instructions_) {
switch (inst) {
case LEAVE:
std::cout << "] ";
break;
case ENTER_TUPLE:
std::cout << "Tuple[";
break;
case ENTER_OBJECT:
std::cout << "Object[";
break;
case SKIP:
std::cout << "Skip ";
break;
case SPECIALIZE_TENSOR:
std::cout << "SpecializeTensor ";
break;
case SPECIALIZE_OPTIONAL_TENSOR:
std::cout << "SpecializeOptionalTensor ";
break;
case SPECIALIZE_OPTIONAL:
std::cout << "SpecializeOptional ";
break;
}
}
std::cout << "\n";
}
// Builds the ArgumentSpec for one invocation by replaying instructions_
// (produced by scan()) against the inputs taken from `input` via last().
//
// The replay is a small stack machine. stack[k] points at the next
// unconsumed IValue at nesting level k, and stack_top is the current level.
// SPECIALIZE_* and SKIP consume one IValue from the current level.
// ENTER_TUPLE / ENTER_OBJECT consume an aggregate and push a pointer to its
// first element / slot. LEAVE pops back to the enclosing level.
// Tensors and optionals are appended to the spec in instruction order,
// which is the same order specializeTypes() reads them back in.
//
// with_grad: passed through to ArgumentSpec::addTensor for every tensor.
// input: the value stack; presumably only its trailing num_inputs_ entries
//   are read (via last()) - confirm against last()'s definition.
ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
const {
ArgumentSpec spec(num_tensors_, num_optionals_);
// NOTE(review): there is no bounds check on stack_top below. This relies on
// scan() never keeping an ENTER_* whose push would exceed
// ARG_SPEC_DEPTH_LIMIT levels.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
const IValue* stack[ARG_SPEC_DEPTH_LIMIT]; // The stack of IValue lists
// The stack gets initialized with the input list
stack[0] = last(input, num_inputs_).begin();
size_t stack_top = 0; // offset to the top of the stack
for (Inst inst : instructions_) {
switch (inst) {
case SPECIALIZE_OPTIONAL_TENSOR: {
// consume a tensor optional and add to the argspec
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
auto& arg = *stack[stack_top]++;
spec.addOptional(arg);
// A present Optional[Tensor] also records the tensor itself, so the
// tensor offset only advances when the optional is not None.
if (!arg.isNone()) {
spec.addTensor(arg, with_grad);
}
} break;
case SPECIALIZE_TENSOR:
// consume a tensor and add to the argspec
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
spec.addTensor(*stack[stack_top]++, with_grad);
break;
case SPECIALIZE_OPTIONAL:
// consume a non-tensor optional and add to the argspec
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
spec.addOptional(*stack[stack_top]++);
break;
case ENTER_TUPLE: {
// consume tuple
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
const IValue* iv = stack[stack_top]++;
AT_ASSERT(iv->isTuple(), "Expected Tuple but got ", iv->tagKind());
// Reads the Tuple pointer directly out of the IValue's storage rather
// than through an accessor. NOTE(review): this depends on IValue's
// internal layout (payload first). It is presumably done to avoid
// refcount traffic; confirm before touching IValue.
auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
// Taking &elements()[0] assumes a non-empty tuple. scan() rewrites any
// aggregate with nothing to specialize into a SKIP, so a surviving
// ENTER_TUPLE always has at least one element instruction.
auto tup_ptr = &p->elements()[0];
// push list of tuple elements to the stack
stack[++stack_top] = tup_ptr;
} break;
case ENTER_OBJECT: {
// consume object
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
const IValue* iv = stack[stack_top]++;
AT_ASSERT(iv->isObject(), "Expected Object but got ", iv->tagKind());
// Same non-empty assumption as ENTER_TUPLE, applied to the object slots.
auto obj_ptr = &iv->toObjectRef().slots()[0];
// push list of object elements to the stack
stack[++stack_top] = obj_ptr;
} break;
case SKIP:
// consume and skip an element
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Assign)
stack[stack_top]++;
break;
case LEAVE:
// Done with the current aggregate; resume consuming from its parent.
--stack_top;
break;
}
}
return spec;
}
// For every input of `graph`, computes the most detailed type that can be
// inferred from `spec` and applies it to the graph.
//
// This replays instructions_ in the same order as create(), so tensor and
// optional entries of `spec` are read back in the order create() recorded
// them. A stack of type cursors mirrors create()'s IValue stack. A parallel
// stack of refined element-type lists, plus one builder per open aggregate,
// rebuilds Tuple / Class types when their LEAVE is reached.
//
// A top-level input whose refined type is still an Optional was seen as
// None. Its uses are rewired to a None constant instead of being retyped.
void ArgumentSpecCreator::specializeTypes(
    Graph& graph,
    const ArgumentSpec& spec) const {
  const auto declared_types =
      fmap(graph.inputs(), [](Value* v) { return v->type(); });

  // refined.back() collects refined types for the innermost open level.
  std::vector<std::vector<TypePtr>> refined(1);
  // cursors.back() points at the next unconsumed declared type of that level.
  std::vector<const TypePtr*> cursors{declared_types.data()};
  // One builder per open aggregate; it wraps refined.back() into a type.
  std::vector<std::function<TypePtr()>> builders;
  size_t next_tensor = 0; // number of specialized tensors seen so far
  size_t next_optional = 0; // number of specialized optionals seen so far

  // Consumes and returns the next declared type at the current level.
  auto take = [&cursors]() -> const TypePtr& { return *cursors.back()++; };

  for (const Inst inst : instructions_) {
    switch (inst) {
      case SPECIALIZE_OPTIONAL_TENSOR: {
        const TypePtr& declared = take();
        if (spec.isPresent(next_optional++)) {
          const auto& tensor = spec.tensorAt(next_tensor++);
          AT_ASSERT(tensor.defined());
          refined.back().emplace_back(tensor.toType());
        } else {
          // Seen as None: keep the declared Optional[Tensor] type.
          refined.back().emplace_back(declared);
        }
        break;
      }
      case SPECIALIZE_TENSOR: {
        // The declared Tensor type is replaced entirely, so just skip it.
        ++cursors.back();
        const auto& tensor = spec.tensorAt(next_tensor++);
        if (tensor.defined()) {
          refined.back().emplace_back(tensor.toType());
        } else {
          refined.back().emplace_back(TensorType::get()->withUndefined());
        }
        break;
      }
      case SPECIALIZE_OPTIONAL: {
        const bool present = spec.isPresent(next_optional++);
        auto optional_type = take()->expect<OptionalType>();
        if (present) {
          refined.back().emplace_back(optional_type->getElementType());
        } else {
          refined.back().emplace_back(optional_type);
        }
        break;
      }
      case ENTER_TUPLE: {
        auto tuple_type = take()->expect<TupleType>();
        cursors.emplace_back(tuple_type->elements().data());
        refined.emplace_back();
        builders.emplace_back(
            [&refined] { return TupleType::create(refined.back()); });
        break;
      }
      case ENTER_OBJECT: {
        auto class_type = take()->expect<ClassType>();
        cursors.emplace_back(class_type->containedTypes().data());
        refined.emplace_back();
        builders.emplace_back([&refined, class_type] {
          return class_type->refine(refined.back());
        });
        break;
      }
      case SKIP:
        // Not specialized: pass the declared type through unchanged.
        refined.back().emplace_back(take());
        break;
      case LEAVE: {
        // Build the aggregate from its refined element types *before*
        // popping them, then append it to the enclosing level.
        TypePtr aggregate = builders.back()();
        refined.pop_back();
        builders.pop_back();
        cursors.pop_back();
        refined.back().emplace_back(std::move(aggregate));
        break;
      }
    }
  }
  AT_ASSERT(refined.size() == 1);
  // FIXME: doing this only on the graph inputs captures top-level optionals
  // but not optionals nested in tuples or objects. Handling those would
  // require investigating the uses of the inputs in detail to change the
  // accesses / unwrapping.
  auto graph_inputs = graph.inputs();
  const auto& top_level = refined.back();
  for (const auto idx : c10::irange(graph_inputs.size())) {
    const TypePtr& refined_type = top_level[idx];
    if (!refined_type->cast<OptionalType>()) {
      graph_inputs[idx]->setType(refined_type);
      continue;
    }
    // An optional input that was not specialized above is None, so
    // disconnect the input and replace its uses with a constant.
    WithInsertPoint guard(*graph.nodes().begin());
    auto none_value = graph.insertConstant({});
    graph_inputs[idx]->replaceAllUsesWith(none_value);
  }
}
} // namespace jit
} // namespace torch