forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunction_impl.h
157 lines (128 loc) · 4.56 KB
/
function_impl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#pragma once
#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
namespace torch {
namespace jit {
// A Function implementation that wraps a Graph. For each SpecializationKey it
// lazily creates, and then caches, an optimized copy of the graph and a
// GraphExecutor built from that copy.
struct TORCH_API GraphFunction : public Function {
  // Takes ownership of all three arguments by moving them into the
  // corresponding members:
  //   name             - returned by qualname()
  //   graph            - the original, non-optimized graph
  //   function_creator - callback kept for ensure_defined()
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  GraphFunction(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      std::function<void(GraphFunction&)> function_creator)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        function_creator_(std::move(function_creator)) {}

  // Always true for this type.
  bool isGraphFunction() const override {
    return true;
  }

  // Synchronous entry point; the implementation lives outside this header.
  void run(Stack& stack) override;

  // Future-returning counterpart of run(); the implementation lives outside
  // this header. `taskLauncher` defaults to at::launch.
  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch) override;

  // Returns the original, non-optimized graph.
  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

  // Returns the optimized graph for the current specialization, building it
  // on first request. The cache slot is filled with a copy of graph_, which
  // is then passed through preoptimizeGraph() only when
  // getGraphExecutorOptimize() is true. compile_mutex is held throughout.
  std::shared_ptr<Graph> optimized_graph() const {
    std::lock_guard<std::recursive_mutex> guard(compile_mutex);
    auto& cached = optimized_graphs_[currentSpecialization()];
    if (!cached) {
      cached = graph_->copy();
      if (getGraphExecutorOptimize()) {
        preoptimizeGraph(*cached);
      }
    }
    return *cached;
  }

  // Returns the qualified name supplied at construction.
  const c10::QualifiedName& qualname() const override {
    return name_;
  }

  // If this function isn't defined yet, run its function_creator_ callback.
  void ensure_defined() override;

  // Number of inputs of the original graph.
  size_t num_inputs() const override {
    const std::shared_ptr<Graph> original = graph();
    return original->inputs().size();
  }

  // Stores `schema` (replacing any previously stored one) and returns *this
  // so calls can be chained.
  Function& setSchema(FunctionSchema schema) override {
    schema_ = make_unique<FunctionSchema>(std::move(schema));
    return *this;
  }

  // Implemented outside this header; see the note on schema_ about caching.
  const FunctionSchema& getSchema() const override;

  // Forwards to the debug state of the executor returned by get_executor().
  GraphExecutorState getDebugState() {
    GraphExecutor& executor = get_executor();
    return executor.getDebugState();
  }

  // Deprecated: emits a warning and unconditionally returns true.
  bool is_optimized() const {
    TORCH_WARN(
        "GraphFunction::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  // Fails via TORCH_CHECK unless the original graph has exactly one output.
  void check_single_output() {
    TORCH_CHECK(
        graph()->outputs().size() == 1,
        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
  }

  // Returns the executor for the current specialization, creating it on first
  // use. ensure_defined() runs before compile_mutex is taken; creation checks
  // for a single output and then builds the executor from optimized_graph().
  GraphExecutor& get_executor() {
    ensure_defined();
    std::lock_guard<std::recursive_mutex> guard(compile_mutex);
    auto& slot = executors_[currentSpecialization()];
    if (!slot) {
      check_single_output();
      slot = GraphExecutor(optimized_graph(), name_.name());
    }
    return *slot;
  }

  // Looks up the execution plan for `stack` (with the given bailOut value) on
  // the current executor and hands its `code` to `f`. Always returns true.
  bool call(
      Stack& stack,
      size_t bailOut,
      c10::function_ref<void(const Code&)> f) override {
    GraphExecutor& executor = get_executor();
    f(executor.getPlanFor(stack, bailOut).code);
    return true;
  }

 private:
  // Index into optimized_graphs_ and executors_.
  enum SpecializationKey {
    AutocastOff,
    CpuAutocastOn,
    GpuAutocastOn,
    CpuGpuAutocastOn,

    // Number of specializations; sizes the per-specialization arrays.
    // (Must stay the last entry.)
    TotalCount
  };

  // Selects the cache slot to use; implemented outside this header.
  SpecializationKey currentSpecialization() const;

  c10::QualifiedName name_;

  // The original, non-optimized graph (kept for debugging and for inlining).
  std::shared_ptr<Graph> graph_;

  // Lazily computed optimized graphs, one slot per SpecializationKey. Used
  // for inlining. Mutable so the const optimized_graph() can fill a slot.
  mutable std::array<
      c10::optional<std::shared_ptr<Graph>>,
      SpecializationKey::TotalCount>
      optimized_graphs_;

  // A GraphFunction may be invoked from several threads, so this lock must be
  // held while a graph executor is first initialized or an optimized graph is
  // computed. It is a reentrant mutex so that one locking method can call
  // another (e.g. get_executor() calling optimized_graph()) without
  // deadlocking.
  mutable std::recursive_mutex compile_mutex;

  // Lazily created executors, one slot per SpecializationKey.
  std::array<c10::optional<GraphExecutor>, SpecializationKey::TotalCount>
      executors_;

  // Optional function that actually creates the method when ensure_defined()
  // is called. The compiler uses this so that it can construct methods out of
  // order.
  std::function<void(GraphFunction&)> function_creator_;

  // If absent, a default schema is generated from the graph. Mutable because
  // getSchema() caches that default schema when it is requested before any
  // call to setSchema().
  mutable std::unique_ptr<FunctionSchema> schema_;
};
// Shorthands for dynamic_cast<GraphFunction*>: convert a generic Function to
// the GraphFunction it actually is.
// NOTE(review): the definitions are not in this header. Presumably the
// noexcept, pointer-returning tryToGraphFunction yields nullptr when the
// Function is not a GraphFunction, while the reference-returning
// toGraphFunction overloads report an error instead -- confirm against the
// implementation.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);
} // namespace jit
} // namespace torch