forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
check_strict_fusion.cpp
132 lines (116 loc) · 3.84 KB
/
check_strict_fusion.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <torch/csrc/jit/passes/check_strict_fusion.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <unordered_map>
namespace torch {
namespace jit {
namespace {
// Reports whether `value` resolves (via getModuleName) to the
// `torch.jit.strict_fusion` class. A value for which no module name can be
// determined is never considered a strict-fusion object.
bool isStrictFusion(Value* value) {
  const auto module_name = getModuleName(value);
  if (!module_name.has_value()) {
    return false;
  }
  return *module_name == "__torch__.torch.jit.strict_fusion";
}
} // namespace
// Returns true when `k` is one of the node kinds this pass treats as a fusion
// guard: TensorExprDynamicGuard, TypeCheck, CudaFusionGuard, or
// RequiresGradCheck.
static bool fusionGuardCheck(Symbol k) {
  if (k == Symbol::prim("TensorExprDynamicGuard")) {
    return true;
  }
  if (k == prim::TypeCheck) {
    return true;
  }
  if (k == prim::CudaFusionGuard) {
    return true;
  }
  return k == prim::RequiresGradCheck;
}
// Gathers the nodes that feed (transitively, through node inputs) into
// `guarding_if`, starting from `guarding_if` itself.
//
// The walk is restricted to nodes that live in the same block as `enter_node`
// and are not positioned before it. It does not look through the inputs of
// fusion-guard nodes (see fusionGuardCheck), although the guard nodes
// themselves are included in the result.
//
// Returns the set of every node visited.
static std::unordered_set<Node*> collectValuesUsedInGuard(
    Node* guarding_if,
    Node* enter_node) {
  std::unordered_set<Node*> seen;
  // Explicit LIFO worklist => depth-first traversal.
  std::vector<Node*> stack;
  stack.push_back(guarding_if);

  while (!stack.empty()) {
    Node* current = stack.back();
    stack.pop_back();
    seen.insert(current);

    // Guard nodes test Tensor inputs directly; whatever produces their inputs
    // is not part of the additional guard logic that was inserted.
    if (fusionGuardCheck(current->kind())) {
      continue;
    }

    for (Value* input : current->inputs()) {
      Node* producer = input->node();
      // Skip producers outside the region of interest, and ones already
      // recorded.
      const bool skip = producer->isBefore(enter_node) ||
          producer->owningBlock() != enter_node->owningBlock() ||
          seen.count(producer) != 0;
      if (!skip) {
        stack.push_back(producer);
      }
    }
  }
  return seen;
}
// Validates the region between `enter_node` (a prim::Enter) and the following
// prim::Exit in the same node list.
//
// Nodes in that region are split into:
//   * "guarding ifs": prim::If nodes whose condition is produced by a fusion
//     guard (see fusionGuardCheck), and
//   * everything else, which is a candidate unfused operator.
//
// Throws ErrorReport (located at the source range of the node producing
// `enter_node`'s input) when:
//   * more than one guarding if is found, or
//   * any candidate node is not among the nodes that the single guarding if
//     depends on (see collectValuesUsedInGuard). With no guarding if at all,
//     every candidate node is reported.
static void checkForUnfusedOps(Node* enter_node) {
  std::vector<Node*> unsupported_nodes;
  std::vector<Node*> guarding_ifs; // if multiple, we will throw
  for (Node* node = enter_node->next(); node->kind() != prim::Exit;
       node = node->next()) {
    if (node->kind() == prim::If &&
        fusionGuardCheck(node->input()->node()->kind())) {
      guarding_ifs.push_back(node);
      continue;
    }
    unsupported_nodes.push_back(node);
  }

  if (guarding_ifs.size() > 1) {
    std::stringstream ss;
    ss << "Found multiple fusions: \n";
    for (Node* n : guarding_ifs) {
      ss << *n << "\n";
    }
    throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
  }

  // autodiff/nnc both insert a number of guards, see
  // `CudaFusionViewGuard Example Graph`.
  // To check for unfused nodes, look at nodes whose outputs are not depended
  // on by the fusion guard. The search is restricted to values after the
  // first node in the prim::Enter block.
  std::unordered_set<Node*> guarding_check_nodes;
  if (guarding_ifs.size() == 1) {
    guarding_check_nodes =
        collectValuesUsedInGuard(guarding_ifs[0], enter_node);
  }

  std::vector<Node*> unfused_nodes_not_used_in_guard;
  for (Node* unfused : unsupported_nodes) {
    if (!guarding_check_nodes.count(unfused)) {
      unfused_nodes_not_used_in_guard.push_back(unfused);
    }
  }

  if (!unfused_nodes_not_used_in_guard.empty()) {
    std::stringstream ss;
    ss << "Found unfused operators: \n";
    for (Node* unfused : unfused_nodes_not_used_in_guard) {
      ss << "\t";
      if (unfused->maybeSchema()) {
        ss << unfused->schema();
      } else {
        // Nodes without a schema are identified by their kind. The original
        // code computed this string but never wrote it to the message.
        ss << unfused->kind().toDisplayString();
      }
      ss << "\n";
    }
    throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
  }
}
void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == prim::Enter && isStrictFusion(n->input())) {
checkForUnfusedOps(n);
}
}
// TODO: remove context manager after checks
// TODO: improve control flow not taken, right now always errors
}
} // namespace jit
} // namespace torch