forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
batch_mm.cpp
506 lines (469 loc) · 18.1 KB
/
batch_mm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
#include <torch/csrc/jit/passes/batch_mm.h>
#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <ATen/ATen.h>
#include <algorithm>
#include <unordered_map>
#include <utility>
namespace torch {
namespace jit {
namespace {
// Alias-analysis kind passed to the custom operators registered in this file.
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  constexpr auto kind = c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
  return kind;
}
} // namespace
// This pass looks for trees in the graph, where leaves are mm ops, and the
// inner vertices are add nodes. Once we have such a tree they can be reduced to
// two concats and a single mm (basically into a single multiply of a wide
// matrix, with a tall matrix). Such patterns show up mostly in backward of
// RNNs, since the derivative of many uses of matrix multiplies with same
// weights forms exactly such a tree (note that it's usually also highly
// imbalanced i.e. has O(n) depth).
//
// This (or any tree of adds of MMs):
//
// +------+ +------+ +------+ +------+ +------+
// | | | | | | | | | |
// | L1 | | R1 | + | L2 | | R2 | = | O |
// | | | | | | | | | |
// +------+ +------+ +------+ +------+ +------+
//
// can be basically transformed into a single MM which looks like this
// (we concat all lhs operands, concat rhs operands, do mm):
//
// +------+
// | |
// | R1 |
// | |
// +------+
// | |
// | R2 |
// | |
// +------+
// +------+------+ +------+
// | | | | |
// | L1 | L2 | | O |
// | | | | |
// +------+------+ +------+
// Note [Further optimizations]
// It would be straightforward to extend the TreeToken class to also detect if
// all MMs had the same lhs/rhs. In such case it's more efficient to expand the
// lhs and use bmm + sum instead of repeating it in memory via concat.
// Note [Overlapping trees]
// Additionally it wouldn't be too hard to add support for partially overlapping
// trees. Right now it's forbidden in the algorithm (only a single tree will
// be allowed), so theoretically we might miss some optimization options,
// especially that the rejected tree could be much larger. I didn't implement
// that because it's not necessary for the simple RNN cases I saw, so I decided
// to keep stuff simple. If we ever get around implementing this, the right
// solution is probably to fuse MMs for the common part, and assume it's an
// input leaf for the outer two parts (I don't think it's beneficial to
// recompute, unless the subtree is super small, but let's not get into such
// details).
// The algorithm we're using is simple. We're iterating through the graph in the
// topological order and labeling nodes with TreeTokens. Then, we look for roots
// of the trees we formed and fuse them.
// Smallest tree, counted in mm leaves, that BatchMMTreeReduce will fuse.
// Tunable: raise it if larger trees turn out to be needed for a speedup.
static constexpr size_t min_fusion_size{4};
// Returns true when every tensor in `inputs` has the same sizes as the first
// one. An empty list is reported as uniform; checking for it up front avoids
// indexing inputs[0] when there is no first element.
static bool have_same_shape(at::TensorList inputs) {
  if (inputs.empty()) {
    return true;
  }
  const auto expected_sizes = inputs[0].sizes();
  return std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      });
}
// True when every input has stride(0) == 1 and stride(1) == size(0). Callers
// use this to decide whether to concatenate the transposed views instead of
// the tensors themselves.
static bool should_be_transposed(at::TensorList inputs) {
  for (const at::Tensor& tensor : inputs) {
    const bool dim0_has_unit_stride = tensor.stride(0) == 1;
    if (!dim0_has_unit_stride || tensor.stride(1) != tensor.size(0)) {
      return false;
    }
  }
  return true;
}
// Returns a vector holding t() of every tensor in `inputs`, in the same order.
static std::vector<at::Tensor> transpose_inputs(at::TensorList inputs) {
  std::vector<at::Tensor> transposed;
  transposed.reserve(inputs.size());
  for (const at::Tensor& input : inputs) {
    transposed.push_back(input.t());
  }
  return transposed;
}
// Heuristic deciding whether the concat + single mm path should be used for a
// tree whose operands all look like `lhs` and `rhs`.
// The thresholds come from simple benchmarks of fp32 gemms on a TITAN V.
static bool shape_is_fast_for_reduce(
    const at::Tensor& lhs,
    const at::Tensor& rhs) {
  const size_t rows = lhs.size(0);
  const size_t inner = lhs.size(1);
  const size_t cols = rhs.size(1);
  if (inner < 512) {
    return true;
  }
  // Note that a dimension of exactly 256 satisfies neither condition.
  const bool both_small = rows < 256 && cols < 256;
  const bool both_large = rows > 256 && cols > 256;
  return both_small || both_large;
}
// Runtime implementation of prim::MMTreeReduce.
// The operand count sits on top of the stack; below it are that many tensors,
// the first half being lhs operands and the second half the matching rhs
// operands. The operator pushes a single tensor: either one mm over the
// concatenated operands, or mm(lhs_0, rhs_0) with every other mm(lhs_i, rhs_i)
// added to it in place.
RegisterOperators mm_tree_reduction_reg({Operator(
    "prim::MMTreeReduce(...) -> Tensor",
    [](Stack& stack) {
      const auto operand_count = pop(stack).toInt();
      std::vector<at::Tensor> operands;
      operands.reserve(operand_count);
      for (auto slot = stack.end() - operand_count; slot != stack.end();
           ++slot) {
        operands.push_back(std::move(*slot).toTensor());
      }
      drop(stack, operand_count);

      AT_ASSERT(!operands.empty());
      AT_ASSERT(operands.size() % 2 == 0);
      const size_t pair_count = operands.size() / 2;
      const at::TensorList all_operands(operands);
      const auto lhs_inputs = all_operands.slice(0, pair_count);
      const auto rhs_inputs = all_operands.slice(pair_count);

      // TODO: checking this is not free, so we should stop if this keeps
      // failing
      const bool use_single_mm = have_same_shape(lhs_inputs) &&
          have_same_shape(rhs_inputs) &&
          shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0]);

      if (!use_single_mm) {
        // Fallback: multiply pair by pair and accumulate in place.
        at::Tensor sum = at::mm(lhs_inputs[0], rhs_inputs[0]);
        for (size_t i = 1; i < pair_count; ++i) {
          sum.add_(at::mm(lhs_inputs[i], rhs_inputs[i]));
        }
        push(stack, std::move(sum));
        return;
      }

      // Sometimes lhs_inputs or rhs_inputs are not contiguous, which sends
      // at::cat down a slow path. When possible, view them as contiguous by
      // transposing, concatenate along the other dimension, and transpose the
      // result back.
      at::Tensor lhs;
      if (should_be_transposed(lhs_inputs)) {
        std::vector<at::Tensor> lhs_views = transpose_inputs(lhs_inputs);
        lhs = at::cat(lhs_views, /*dim=*/0).t();
      } else {
        lhs = at::cat(lhs_inputs, /*dim=*/1);
      }

      at::Tensor rhs;
      if (should_be_transposed(rhs_inputs)) {
        std::vector<at::Tensor> rhs_views = transpose_inputs(rhs_inputs);
        rhs = at::cat(rhs_views, /*dim=*/1).t();
      } else {
        rhs = at::cat(rhs_inputs, /*dim=*/0);
      }

      push(stack, at::mm(lhs, rhs));
    },
    aliasAnalysisIsSpecialCase())});
// TreeTokens will be used to label nodes of the graph, if the nodes will fit
// our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
// when we reach node N with inputs A and B, then A and B have already been
// processed, and we can try to unify their TreeTokens (if they have them)
// and build a larger tree.
// Label attached to a graph node that takes part in an mm/t/add tree.
// A default-constructed token is "invalid": it has no node and is not a root.
struct TreeToken {
  uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
  Node* node = nullptr; // Root node of the (sub)tree this token describes
  bool is_root = false; // False once the subtree was merged into a larger one

  // Token for a single aten::mm leaf.
  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }

  // Token for an aten::t node `t` applied directly to an aten::mm leaf.
  // On success the input token stops being a root.
  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    // An invalid input token carries no node (e.g. the token of a transpose
    // whose operand was not an mm), so it must be rejected before calling
    // matches() on it.
    if (inp_token.node == nullptr ||
        !inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false;
    return token;
  }

  // Token for an aten::add node joining the subtrees `l` and `r`.
  // NB: the returned token might be invalid, so make sure to check its boolean
  // value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    // See Note [Overlapping trees]
    if (&l == &r || !l.is_root || !r.is_root)
      return token;
    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    l.is_root = r.is_root =
        false; // Reserve the subtrees, so they can't be used again.
    return token;
  }

  // A token is usable only while it is the root of its tree.
  explicit operator bool() const {
    return is_root;
  }

  // Walks the tree rooted at `node` and returns its mm leaves. Every
  // t(mm(A, B)) leaf is rewritten into mm(t(B), t(A)), inserted before the
  // original mm, and all uses of the transpose are redirected to it.
  // Asserts if the tree contains anything other than mm, t or add nodes.
  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // (AB)^T == B^TA^T
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (
          n->matches(
              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};
// Which mm operand is shared by a batch. The numeric value is stored in the
// "side" attribute of prim::MMBatchSide nodes and cast back when they run.
enum class Side {
  LHS = 0,
  RHS = 1,
};
// Finds trees of aten::add nodes whose leaves are aten::mm (optionally wrapped
// in aten::t) inside `block`, and replaces every tree with at least
// min_fusion_size leaves by a single prim::MMTreeReduce node. Sub-blocks of
// other nodes are processed recursively. Replaced nodes are left in place for
// a later dead code elimination run.
static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) {
  auto graph = block->owningGraph();

  // Look for trees in the block
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      tokens[node] = TreeToken::mm(node);
    } else if (
        node->matches("aten::t(Tensor self) -> Tensor") &&
        !alias_db.hasWriters(node)) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        // Only remember valid tokens: an invalid one has no node, and a later
        // transpose of this node would otherwise look it up and use it.
        if (auto token = TreeToken::transpose(node, input_it->second)) {
          tokens[node] = token;
        }
      }
    } else if (
        // MMTreeReduce computes a plain sum of products, so only adds with a
        // constant alpha equal to 1 may become inner vertices of a tree.
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            /*const_inputs=*/attr::alpha) &&
        node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
        !alias_db.hasWriters(node)) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
      // We could treat a subtree with multiple uses as if it was overlapping.
      // XXX: uses().size() == 1 is also something that guarantees that this
      // transform is valid, because we know for sure that none of these
      // operands depend on the result of the other. If we were to remove this,
      // we need to compute a transitive closure and actually check the
      // dependencies.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto subblock : node->blocks()) {
        BatchMMTreeReduce(subblock, alias_db);
      }
    }
  }

  // Merge trees we've found
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    // All lhs operands come first, followed by the rhs operands in the same
    // order; the operator splits its inputs in half accordingly.
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());
    // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
  }
}
// Heuristic for prim::MMBatchSide: the batched path is used only when the
// non-shared operand has at most this many elements.
static bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  // Cutoff chosen by benchmarking on a TITAN V
  constexpr int64_t max_fast_numel = 1024 * 2048;
  return other_side_input.numel() <= max_fast_numel;
}
// Runtime implementation of prim::MMBatchSide.
// The node's first input is the operand shared by all multiplications and the
// remaining inputs are the other operands; the "side" attribute says whether
// the shared operand is the lhs or the rhs. One result per other operand is
// pushed, in input order.
RegisterOperators mm_batch_side_reg({Operator(
    prim::MMBatchSide,
    [](const Node* node) -> Operation {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);

        // The code below indexes other_side_inputs[0].
        AT_ASSERT(!other_side_inputs.empty());

        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(other_side_inputs[0])) {
          // Fast path: concatenate the other operands, do a single mm and
          // split the result back into one chunk per operand.
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          // Fallback: one mm per operand.
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
      };
    },
    aliasAnalysisIsSpecialCase())});
// Collects the aten::mm nodes (without writers, living in the block that
// defines `value`) which use `value` as exactly one of their operands.
// Returns {nodes using it as lhs, nodes using it as rhs}. Each list is sorted
// in graph order and drops every node for which
// alias_db.couldMoveBeforeTopologically(node, earlier_kept_node) is false.
static std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
    Value* value,
    AliasDb& alias_db) {
  const auto keep_independent = [&alias_db](std::vector<Node*> candidates) {
    if (candidates.empty()) {
      return candidates;
    }
    std::sort(candidates.begin(), candidates.end(), [](Node* a, Node* b) {
      return a->isBefore(b);
    });
    // Filter out dependent MMs. This algorithm might do very badly if e.g. you
    // have a lot of independent MMs, that depend on the first one, but I doubt
    // this will be a common scenario.
    const size_t count = candidates.size();
    for (size_t anchor = 0; anchor < count; ++anchor) {
      if (candidates[anchor] == nullptr) {
        continue;
      }
      for (size_t later = anchor + 1; later < count; ++later) {
        if (candidates[later] != nullptr &&
            !alias_db.couldMoveBeforeTopologically(
                candidates[later], candidates[anchor])) {
          candidates[later] = nullptr;
        }
      }
    }
    // Compact the list, discarding the entries that were nulled out above.
    candidates.erase(
        std::remove(
            candidates.begin(), candidates.end(), static_cast<Node*>(nullptr)),
        candidates.end());
    return candidates;
  };

  Block* const home_block = value->node()->owningBlock();
  std::vector<Node*> lhs_users; // mm nodes where value is the lhs operand
  std::vector<Node*> rhs_users; // mm nodes where value is the rhs operand
  for (const Use& use : value->uses()) {
    Node* user = use.user;
    if (user->owningBlock() != home_block) {
      continue;
    }
    if (!user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      continue;
    }
    if (alias_db.hasWriters(user)) {
      continue;
    }
    // mm(value, value) is skipped: value must be on one side only.
    if (use.offset == 0 && user->inputs()[1] != value) {
      lhs_users.push_back(user);
    } else if (use.offset == 1 && user->inputs()[0] != value) {
      rhs_users.push_back(user);
    }
  }
  return std::make_pair(
      keep_independent(std::move(lhs_users)),
      keep_independent(std::move(rhs_users)));
}
// Batches aten::mm nodes of `block` that share one operand: when a value is
// used as the lhs (or as the rhs) of at least how_many_is_many independent mm
// nodes, those nodes are replaced by one prim::MMBatchSide node with one output
// per original mm. Sub-blocks of all other nodes are processed recursively.
static void BatchMMSide(Block* block, AliasDb& alias_db) {
// Minimum number of mm uses on one side needed to trigger batching.
// NB: 8 is the current loop unrolling factor
static constexpr size_t how_many_is_many = 8;
// Replaces the mm nodes in `mms` (which share their `side` operand) with a
// single prim::MMBatchSide node.
const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
AT_ASSERT(!mms.empty());
// Going from the back, move every mm right before its successor in the list.
// The move is expected to always succeed here.
for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
AT_ASSERT(move_ok);
}
// The batched node is inserted at the first mm of the list.
WithInsertPoint insert_guard{mms[0]};
Graph* graph = mms[0]->owningGraph();
Node* batch_mm = graph->create(
prim::MMBatchSide,
/*inputs=*/{},
/*num_outputs=*/mms.size());
graph->insertNode(batch_mm);
// The side is stored as an int attribute; the operator casts it back to Side.
batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
// Input 0 is the shared operand, followed by the other operand of each mm.
Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
batch_mm->addInput(const_side);
for (const auto i : c10::irange(mms.size())) {
batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
// Output i of the batched node takes over all uses of mms[i].
mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
}
};
// Values whose uses were already examined, so each value is handled once.
std::unordered_set<Value*> considered_values;
// NOTE(review): batch_side moves nodes of this block while the loop below is
// iterating over block->nodes() — confirm the iteration tolerates that and
// that no nodes are skipped as a result.
for (Node* node : block->nodes()) {
if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor") &&
!alias_db.hasWriters(node)) {
for (Value* input : node->inputs()) {
if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
continue;
}
// .first holds the mm nodes using `input` as lhs, .second those using it as
// rhs.
auto uses_with_many = gatherIndependentMMUses(input, alias_db);
if (uses_with_many.first.size() >= how_many_is_many) {
batch_side(uses_with_many.first, Side::LHS);
}
if (uses_with_many.second.size() >= how_many_is_many) {
batch_side(uses_with_many.second, Side::RHS);
}
}
} else {
for (Block* subblock : node->blocks()) {
BatchMMSide(subblock, alias_db);
}
}
}
}
// Returns true if `block`, or any block nested inside its nodes, contains an
// aten node whose schema is mutable.
static bool hasMutableOperators(Block* block) {
  for (Node* node : block->nodes()) {
    const bool mutates = node->kind().is_aten() && node->schema().is_mutable();
    if (mutates) {
      return true;
    }
    const auto nested = node->blocks();
    const bool nested_mutates =
        std::any_of(nested.begin(), nested.end(), [](Block* inner) {
          return hasMutableOperators(inner);
        });
    if (nested_mutates) {
      return true;
    }
  }
  return false;
}
static bool hasMMOperators(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
return true;
}
}
return false;
}
void BatchMM(std::shared_ptr<Graph>& graph) {
if (!hasMMOperators(graph)) {
return;
}
AliasDb alias_db(graph);
BatchMMTreeReduce(graph->block(), alias_db);
BatchMMSide(graph->block(), alias_db);
EliminateDeadCode(graph);
// It's possible that transpose rearrangements have created sequences of
// consecutive transposes that didn't exist before.
// tensor type properties are not guaranteed to be correct
PeepholeOptimize(graph, /*disable_shape_peepholes*/ true);
}
} // namespace jit
} // namespace torch