forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: lower_graph.cpp
157 lines (142 loc) · 5.08 KB
/
lower_graph.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/custom_class.h>
#include <unordered_map>
namespace torch {
namespace jit {
struct Slot {
c10::intrusive_ptr<c10::ivalue::Object> obj;
size_t offset;
bool operator==(const Slot& other) const {
return (this->obj == other.obj && this->offset == other.offset);
}
};
// remove the first module argument, replacing any access of its
// parameters/attributes with extra_ivalue input Slots that hold what value to
// pass into the graph. Used for ONNX export to remove first-class modules
// so it can deal purely with parameters and inputs
// Returns a lowered copy of `g_` (a method graph of module `self`) with the
// module argument at input position `self_offset` removed, plus the list of
// attribute Slots whose values must be appended as extra inputs when calling
// the lowered graph.
std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
    const ModulePtr& self,
    Graph& g_,
    size_t self_offset = 0) {
  std::shared_ptr<Graph> g = g_.copy();
  // Inline to remove method/function calls
  Inline(*g);

  // Slots lifted to graph inputs, in the order the inputs were added.
  std::vector<Slot> extra_ivalues;

  struct SlotHash {
    std::size_t operator()(const Slot& slot) const {
      auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
      auto offset_hash = std::hash<size_t>{}(slot.offset);
      return c10::hash_combine(obj_hash, offset_hash);
    }
  };

  // Maps a slot to its index in extra_ivalues so each attribute is lifted
  // to a graph input at most once and reused thereafter.
  std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;

  // A pending use of a module value: node `n` consumes module `mod` as its
  // input number `offset`.
  struct ToScan {
    ModulePtr mod;
    Node* n;
    size_t offset;
  };
  std::vector<ToScan> to_scan;
  std::vector<Node*> to_clean; // nodes that should be dead at the end

  // Returns the graph input carrying `slot`'s value, adding a new input
  // (typed from the slot's current value) on first request.
  auto getOrAddSlot = [&](const Slot& slot) -> Value* {
    auto it = slot_to_offset.find(slot);
    if (it != slot_to_offset.end()) {
      // Lifted inputs occupy the tail of the input list.
      size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
      return g->inputs().at(ivalues_start + it->second);
    }
    extra_ivalues.emplace_back(slot);
    slot_to_offset[slot] = extra_ivalues.size() - 1;
    return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
  };

  auto self_value = g->inputs().at(self_offset);

  // Seed the worklist with every direct use of the module argument.
  for (Use use : self_value->uses()) {
    to_scan.emplace_back(ToScan{self, use.user, use.offset});
  }
  while (to_scan.size() > 0) {
    auto e = to_scan.back();
    to_scan.pop_back();

    // when we lambda lift forks, first-class modules may be passed across
    // forks. This code recursively lowers the module in the fork call.
    if (e.n->kind() == prim::fork) {
      auto subgraph = e.n->g(attr::Subgraph);
      std::vector<Slot> new_slots;
      std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
      e.n->g_(attr::Subgraph, subgraph);
      // Wire the subgraph's lifted slots in as inputs of the fork node,
      // sharing this graph's lifted inputs where slots coincide.
      for (const Slot& slot : new_slots) {
        e.n->addInput(getOrAddSlot(slot));
      }
      e.n->removeInput(e.offset);
      continue;
    }
    if (e.n->kind() == prim::PythonOp) {
      throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
    }
    if (e.n->kind() != prim::GetAttr) {
      throw ErrorReport(e.n->sourceRange())
          << "temporary: the only valid use of a module is looking up an "
             "attribute but found "
          << *e.n;
    }
    size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
    auto iv = e.mod->getSlot(slot_idx);
    if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
      if (c->is_module()) {
        // GetAttr of a submodule: recurse into the submodule's uses; the
        // GetAttr node itself becomes dead once all its uses are rewritten,
        // so defer its destruction.
        for (Use use : e.n->output()->uses()) {
          to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
        }
        to_clean.emplace_back(e.n);
        continue;
      }
    }
    // Plain (non-module) attribute access: replace every use of the value
    // with the corresponding lifted graph input.
    e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
    e.n->destroy();
  }
  // Destroy deferred submodule GetAttr nodes in reverse discovery order, so
  // consumers of a value are destroyed before its producer (the AT_ASSERT
  // enforces that each node is already use-free).
  while (to_clean.size() > 0) {
    Node* n = to_clean.back();
    AT_ASSERT(!n->hasUses());
    n->destroy();
    to_clean.pop_back();
  }
  AT_ASSERT(!self_value->hasUses());
  g->eraseInput(self_offset);
  return std::make_pair(std::move(g), std::move(extra_ivalues));
}
// Reads the current value of each lifted slot. Plain tensors are returned
// as-is; quantized packed-parameter custom classes are serialized via their
// __getstate__ method. Any other object type aborts the export with an
// informative error.
static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
  std::vector<IValue> loaded;
  loaded.reserve(slots.size());
  for (const Slot& s : slots) {
    auto value = s.obj->getSlot(s.offset);
    if (value.isTensor()) {
      loaded.emplace_back(value.toTensor());
      continue;
    }
    // Unpack quantization packed tensor
    auto ty = value.type();
    const bool is_packed_params =
        ty ==
            getCustomClass(
                "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
        ty ==
            getCustomClass(
                "__torch__.torch.classes.quantized.Conv3dPackedParamsBase") ||
        ty ==
            getCustomClass(
                "__torch__.torch.classes.quantized.LinearPackedParamsBase");
    TORCH_CHECK(
        is_packed_params,
        "Unknown type ",
        ty->repr_str(),
        " encountered in graph lowering. This type is not supported in ONNX export.");
    loaded.emplace_back(
        script::Object(value.toObject()).run_method("__getstate__"));
  }
  return loaded;
}
// Public entry point: lowers away `graph`'s first-class module argument and
// materializes the lifted attribute slots as concrete IValues suitable for
// feeding in as extra graph inputs.
std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
    Graph& graph,
    const ModulePtr& self) {
  std::shared_ptr<Graph> lowered;
  std::vector<Slot> slots;
  std::tie(lowered, slots) = lower_graph(self, graph);
  return std::make_pair(lowered, loadTensors(slots));
}
} // namespace jit
} // namespace torch