forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fixup_trace_scope_blocks.h
47 lines (43 loc) · 1.63 KB
/
fixup_trace_scope_blocks.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#pragma once
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {

// Directly after tracing, we have an ill-formed graph with blocks inserted.
// Each scope is represented by a prim::TracedModuleForward node whose
// `scope` attribute names the owning module (e.g. "__module.relu2.rrr") and
// whose single block holds the ops recorded inside that scope. Example:
//
//   graph(%self : ClassType<Module>,
//         %input.1 : Float(3, 4)):
//     %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//     %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//     %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
//      = prim::TracedModuleForward[scope="__module.relu1"]()
//       block0():
//         %input : Float(3, 4) = aten::relu(%input.1)
//         -> ()
//      = prim::TracedModuleForward[scope="__module.relu2"]()
//       block0():
//          = prim::TracedModuleForward[scope="__module.relu2.rrr"]()
//           block0():
//             %6 : Float(3, 4) = aten::relu(%input)
//             -> ()
//         -> ()
//     return (%6)
//
// The graph is ill-formed because defs do not dominate their uses: `%input`
// is defined inside the "__module.relu1" block but used inside
// "__module.relu2.rrr", and `%6` is defined in the innermost block but
// returned from the top-level graph.
//
// In this pass, we:
//   1) Lift Value defs to as high a scope as needed to ensure that they
//      dominate all their uses. For example, `%input` in the above graph
//      needs to be lifted to the top-level block so that its use in the
//      second `relu` operator is dominated.
//   2) Lambda lift the blocks. This ensures that all values used within
//      each scope have their defs captured.
//   3) Convert the scope blocks into methods on their respective Modules,
//      and convert TracedModuleForward nodes to CallMethod nodes into those
//      methods.
//
// Afterwards, the graph is well-formed and contains proper method calls.
//
// Parameters:
//   graph - the freshly traced graph; this pass rewrites it. It is taken as a
//           non-const reference to the shared_ptr, which means the callee is
//           able to reseat the pointer itself. NOTE(review): whether the
//           implementation actually replaces the pointer or only mutates the
//           pointed-to Graph is not visible from this header; confirm in the
//           .cpp before relying on either behavior.
//   self  - presumably the root Module the trace was recorded from (the
//           `%self` input in the example); step 3 installs the generated
//           methods on it and its submodules. NOTE(review): this is a raw
//           pointer, so nullptr may be an accepted input. Confirm the
//           null-handling in the implementation.
TORCH_API void FixupTraceScopeBlocks(
    std::shared_ptr<Graph>& graph,
    Module* self);

} // namespace jit
} // namespace torch