forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_iterator.h
147 lines (130 loc) · 4.82 KB
/
graph_iterator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <torch/csrc/jit/ir/ir.h>
namespace torch::jit {
// This class facilitates depth-first iteration over all nodes in a graph.
class DepthFirstGraphNodeIterator {
  // The node that the next call to next() will hand out, or nullptr once the
  // traversal has been exhausted.
  Node* current_{nullptr};

 public:
  // Constructs an iterator positioned on the first node of the graph's
  // top-level block. If that block contains no nodes, the iterator starts
  // exhausted and next() returns nullptr straight away.
  //
  // Taking the shared_ptr by const reference lets callers pass temporaries or
  // const handles; every non-const lvalue accepted before still binds.
  explicit DepthFirstGraphNodeIterator(const std::shared_ptr<Graph>& graph)
      : current_(graph->block()->param_node()) {
    // Start "before" the first node (on the block's param node) and step
    // forward. This is the same technique move_into() uses for child blocks.
    // An empty block leads straight to its prim::Return node, which move_up()
    // turns into nullptr, instead of handing out the Return/Param sentinels.
    move_next();
  }

  // Moves up out of the current block and on to the next node (may move up
  // recursively).
  void move_up() {
    if (current_ == nullptr) {
      return;
    }
    // We start from a node in the child block (current_) and find the block
    // that owns it. We then check whether that block is the graph's root
    // block or a sub-block of an If/Loop/With node.
    //
    // If it is the root block we stop, because there is no "up". If it is
    // owned by a node (e.g. If/Loop/With), we pick the next position based on
    // which sub-block we are coming from. This might require moving up again:
    // for example, after reaching the end of the else clause of an If, we go
    // up to the block that contains the If.
    //
    // Similarly, if we then reach the end of that enclosing block, we need to
    // go up again, so this function is recursive.
    //
    // BlockNode (if/loop/with)
    // |
    // [Block1] ... [Block2]
    // |
    // [ Node1, Node2, Node3, FromNode]
    //
    auto parent_block = current_->owningBlock();
    TORCH_INTERNAL_ASSERT(parent_block, "Every node must be owned by a block");
    // Get the node that owns the parent block. This node has to be an If,
    // Loop, or With.
    auto parent_node = parent_block->owningNode();
    if (parent_node == nullptr) {
      // No node owns this block, so we are at the top of the graph. Since we
      // are trying to move up, the traversal is finished.
      current_ = nullptr;
      return;
    }
    // Dispatch on the kind of node that owns the block we are leaving.
    if (parent_node->kind() == prim::If) {
      // Need to check whether we came from the `then` or the `else` branch.
      auto* then_block = parent_node->blocks().at(0);
      auto* else_block = parent_node->blocks().at(1);
      if (parent_block == else_block) {
        // Leaving the else block: move to the node after the If.
        current_ = parent_node->next();
        if (current_->kind() == prim::Return) {
          move_up();
        }
      } else {
        // Leaving the then block: move into the else block if it is not
        // empty.
        TORCH_INTERNAL_ASSERT(parent_block == then_block);
        bool else_block_empty =
            else_block->nodes().begin() == else_block->nodes().end();
        if (!else_block_empty) {
          current_ = *(else_block->nodes().begin());
        } else {
          // The else block is empty, so move to the node after the If.
          current_ = parent_node->next();
          if (current_->kind() == prim::Return) {
            move_up();
          }
        }
      }
    } else if (
        parent_node->kind() == prim::Loop ||
        parent_node->kind() == prim::With) {
      // NOTE(review): only the first block of a Loop/With is ever entered
      // (see move_into); any further blocks on these nodes are skipped.
      // Confirm that this is intended for prim::With.
      current_ = parent_node->next();
      if (current_->kind() == prim::Return) {
        move_up();
      }
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "Only if/loop/with nodes should have child blocks");
    }
  }

  // Moves to the next node in the same block, or up into the parent if the
  // end of the block has been reached.
  void move_next() {
    if (current_ == nullptr) {
      return;
    }
    // Advance to the next node in the current block.
    current_ = current_->next();
    // Reaching the block's prim::Return node means the block is finished, so
    // move upwards (if that makes sense).
    if (current_->kind() == prim::Return) {
      move_up();
    }
  }

  // Moves to the next node in the graph, descending into the first sub-block
  // of an If/Loop/With node when the current node is one.
  void move_into() {
    if (current_ == nullptr) {
      return;
    }
    if (current_->kind() == prim::If || current_->kind() == prim::Loop ||
        current_->kind() == prim::With) {
      // Position "before" the first node of the first sub-block. If that
      // block is empty, the move_next() below reaches its Return node, and
      // move_up() handles the differences between If, Loop and With.
      current_ = current_->blocks().at(0)->param_node();
    }
    move_next();
  }

  // Returns the next Node of the depth-first traversal, or nullptr if no
  // nodes are left.
  Node* next() {
    auto result = current_;
    // Advance ahead of time so the following call has its node ready: descend
    // into sub-blocks if possible, otherwise move to the next node or up and
    // onwards.
    move_into();
    return result;
  }
};
} // namespace torch::jit