forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
function_impl.h
194 lines (161 loc) · 5.97 KB
/
function_impl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#pragma once
#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
namespace torch::jit {
/// A Function whose body is a TorchScript Graph.
///
/// The original, unoptimized graph is kept as-is (see graph()). An optimized
/// copy of it and a GraphExecutor are built lazily and cached, one per
/// SpecializationKey slot (selected by currentSpecialization(), which is
/// defined out of line; the slots correspond to autocast on/off combinations).
///
/// Thread-safety: GraphFunctions may be invoked from multiple threads, so all
/// lazy construction of, and mutation of, the cached optimized graphs and
/// executors is done while holding the reentrant `compile_mutex`.
struct TORCH_API GraphFunction : public Function {
  /// @param name qualified name reported by qualname().
  /// @param graph the original, unoptimized graph.
  /// @param function_creator optional callback that actually creates the
  ///     function body when ensure_defined() is called, so the compiler can
  ///     construct methods out of order.
  /// @param executor_execution_mode optional executor mode forwarded to the
  ///     GraphExecutor built in get_executor(); when absent, the GraphExecutor
  ///     overload without a mode argument is used.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  GraphFunction(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      std::function<void(GraphFunction&)> function_creator,
      c10::optional<ExecutorExecutionMode> executor_execution_mode =
          c10::nullopt)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        executor_execution_mode_(executor_execution_mode),
        function_creator_(std::move(function_creator)) {}

  bool isGraphFunction() const override {
    return true;
  }

  void run(Stack& stack) override;

  /// Returns a copy of the lazy-definition callback (it may be empty).
  std::function<void(GraphFunction&)> function_creator() const {
    return function_creator_;
  }

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch) override;

  /// The original, non-optimized graph (used for debugging and inlining).
  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

  /// Returns the optimized graph for the current specialization, building and
  /// caching it on first use. The cached graph is a copy of graph_; it is run
  /// through preoptimizeGraph() only when getGraphExecutorOptimize() is true.
  std::shared_ptr<Graph> optimized_graph() const {
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& optimized_graph = optimized_graphs_[currentSpecialization()];
    if (optimized_graph) {
      return *optimized_graph;
    }
    optimized_graph = graph_->copy();
    if (getGraphExecutorOptimize()) {
      // force_no_amp_ lets callers skip the AMP pass for graphs that have
      // already been traced through AMP.
      preoptimizeGraph(*optimized_graph, force_no_amp_);
    }
    return *optimized_graph;
  }

  const c10::QualifiedName& qualname() const override {
    return name_;
  }

  // private/unstable api. sets the initial execution mode
  // will not affect executor if there is an existing executor
  // created for this function
  void _set_initial_executor_execution_mode(ExecutorExecutionMode mode) {
    executor_execution_mode_ = mode;
  }

  // private/unstable api. sets flag of whether or not to ignore amp.
  // will not affect executor if there is an existing executor
  // created for this function
  void _set_ignore_amp(bool ignore_amp) {
    force_no_amp_ = ignore_amp;
  }

  // if this isn't yet defined, run its method_creator function
  void ensure_defined() override;

  /// Number of inputs of the original graph.
  size_t num_inputs() const override {
    // Read graph_ directly: going through graph() would copy the shared_ptr
    // (an atomic refcount round-trip) only to read a size.
    return graph_->inputs().size();
  }

  /// Overrides the schema; otherwise getSchema() derives a default one.
  Function& setSchema(FunctionSchema schema) override {
    schema_ = std::make_unique<FunctionSchema>(std::move(schema));
    return *this;
  }

  const FunctionSchema& getSchema() const override;

  /// Debug state of the executor for the current specialization (built on
  /// demand via get_executor()).
  GraphExecutorState getDebugState() {
    return get_executor().getDebugState();
  }

  /// Deprecated: warns and always returns true.
  bool is_optimized() const {
    TORCH_WARN(
        "GraphFunction::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  /// Errors unless the original graph has exactly one output.
  void check_single_output() {
    TORCH_CHECK(
        graph_->outputs().size() == 1,
        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
  }

  /// Returns the GraphExecutor for the current specialization, creating it on
  /// first use from optimized_graph(). Creation requires a single-output graph
  /// (see check_single_output()) and honours executor_execution_mode_ if set.
  GraphExecutor& get_executor() {
    ensure_defined();
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& executor = executors_[currentSpecialization()];
    if (executor) {
      return *executor;
    }
    check_single_output();
    const std::string& name = name_.name();
    // Re-enters compile_mutex; safe because the mutex is recursive.
    std::shared_ptr<Graph> opt_graph = optimized_graph();
    if (!executor_execution_mode_) {
      executor = GraphExecutor(opt_graph, name);
    } else {
      executor = GraphExecutor(opt_graph, name, *executor_execution_mode_);
    }
    return *executor;
  }

  using Function::call;

  /// Looks up (building if needed) the executor's plan for `stack` and passes
  /// its Code to `f`. Always returns true.
  bool call(
      Stack& stack,
      c10::optional<size_t> bailOut,
      c10::function_ref<void(const Code&)> f) override {
    f(get_executor().getPlanFor(stack, bailOut).code);
    return true;
  }

  /// Drops every cached optimized graph so the next optimized_graph() call
  /// rebuilds it. executors_ is not touched.
  void clear_optimized_graphs() {
    // optimized_graph() fills these slots while holding compile_mutex, so the
    // clear must hold the same lock. Otherwise it races with a concurrent
    // lazy build of the same slot.
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    optimized_graphs_.fill(c10::nullopt);
  }

 private:
  enum SpecializationKey {
    AutocastOff,
    CpuAutocastOn,
    GpuAutocastOn,
    CpuGpuAutocastOn,

    // This provides the number of specializations
    // (Must be last entry)
    TotalCount
  };

  // Selects the cache slot for the current call; defined out of line.
  SpecializationKey currentSpecialization() const;

  c10::QualifiedName name_;
  // The original, non-optimized graph
  std::shared_ptr<Graph> graph_; // for debugging and for inlining

  // allows users to specify Simple/Profiling Executor for function
  // TODO: add more executors
  mutable c10::optional<ExecutorExecutionMode> executor_execution_mode_;

  // if invoked on a graph that has already traced through amp
  // don't invoke amp pass
  mutable bool force_no_amp_ = false;
  // Optimized graph, computed lazily. Used for inlining.
  mutable std::array<
      c10::optional<std::shared_ptr<Graph>>,
      SpecializationKey::TotalCount>
      optimized_graphs_;

  // GraphFunctions are invokable from multiple threads, so this lock needs to
  // be held when we're initializing graph executor for the first time or
  // computing the optimized graph. We're using reentrant mutex so that we don't
  // need to worry about causing a deadlock by calling one method from another
  // (e.g. optimized_graph() from get_executor()).
  mutable std::recursive_mutex compile_mutex;

  // executor_[0] - autocast off
  // executor_[1] - autocast cpu on
  // executor_[2] - autocast gpu on
  // executor_[3] - autocast cpu & gpu on
  std::array<c10::optional<GraphExecutor>, SpecializationKey::TotalCount>
      executors_;

  // an optional function that actually creates the method when
  // ensure_defined() is called. This is used by the compiler so
  // that it can construct methods out of order
  std::function<void(GraphFunction&)> function_creator_;

  // if absent, then we generate a default schema based on the graph
  // mutable because getSchema caches the default schema if one is requested
  // before a call to setSchema
  mutable std::unique_ptr<FunctionSchema> schema_;
};
// Downcast helpers from Function to GraphFunction, described as short hands
// for dynamic_cast<GraphFunction*>. Their definitions are out of line.
//
// tryToGraphFunction is noexcept. It presumably yields nullptr when
// `function` is not a GraphFunction (TODO confirm against the definition).
TORCH_API GraphFunction* tryToGraphFunction(Function& function) noexcept;

// The reference-returning overloads cannot signal failure with nullptr. They
// presumably raise an error when `function` is not a GraphFunction (verify in
// the definition).
TORCH_API GraphFunction& toGraphFunction(Function& function);
TORCH_API const GraphFunction& toGraphFunction(const Function& function);
} // namespace torch::jit