forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_module_api.cpp
470 lines (405 loc) · 14.8 KB
/
test_module_api.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
#include <gtest/gtest.h>

#include <mutex>

#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
namespace torch {
namespace jit {
// TorchScript source declaring the module interface `OneInterface`, which
// exposes a single method `one(x, y)`. CloneWithModuleInterface feeds this
// text to import_libs(). The text is Python-style source, so the nested
// indentation inside the raw string is significant and must be kept.
static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
class OneInterface(ModuleInterface):
    def one(self, x: Tensor, y: Tensor) -> Tensor:
        pass
)JIT";
// TorchScript method sources defined on each submodule in
// CloneWithModuleInterface: `one` (the method named by OneInterface) and
// `forward`. Both read the module attribute `attr`. Function bodies must stay
// indented relative to their `def` line for the source to parse.
static const std::vector<std::string> subModuleMethodsSrc = {R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
    return self.attr * x + y + 1
def forward(self, x: Tensor) -> Tensor:
    return self.attr + x
)JIT"};
// TorchScript `forward` for the parent module in CloneWithModuleInterface: it
// calls `one` on both interface-typed submodule attributes and sums the
// results. The function body must stay indented under its `def` line.
static const std::string parentForward = R"JIT(
def forward(self, x: Tensor) -> Tensor:
    return self.subMod1.one(x, x) + self.subMod2.one(x, x)
)JIT";
static void import_libs(
std::shared_ptr<CompilationUnit> cu,
const std::string& class_name,
const std::shared_ptr<Source>& src,
const std::vector<at::IValue>& tensor_table) {
SourceImporter si(
cu,
&tensor_table,
[&](const std::string& name) -> std::shared_ptr<Source> { return src; },
/*version=*/2);
si.loadType(QualifiedName(class_name));
}
// Runs `forward` of a serialized model through Method::run_async with a custom
// task launcher and checks that the launcher was invoked at least twice.
TEST(ModuleAPITest, MethodRunAsync) {
  // Inline definition kept for reference only; the test loads a serialized
  // model file instead (presumably produced from this source -- confirm).
  // Module m("m");
  // m.define(R"(
  //   def forward(self):
  //     r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
  //     r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
  //     return r1.wait() + r2.wait()
  // )");
  std::string filePath(__FILE__);
  // Directory containing this source file, including the trailing separator.
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  // borrow model file from TEST(GraphExecutorTest, runAsync_executor)
  testModelFile.append("test_interpreter_async.pt");
  auto m = load(testModelFile);

  auto counter = 0;
  std::mutex mtx;
  auto launcher = [&](std::function<void()> f) {
    {
      // Scoped guard: the mutex is released even if something throws, and it
      // is not held while the task itself is being launched.
      std::lock_guard<std::mutex> guard(mtx);
      ++counter;
    }
    at::launch(std::move(f));
  };

  auto method = m.get_method("forward");

  std::vector<IValue> stack;
  auto kwargs = std::unordered_map<std::string, at::IValue>();
  auto future = method.run_async(stack, kwargs, launcher);

  future->wait();

  // expect 2 forks and 2 wait callbacks being executed on provided
  // taskLauncher but ivalue::Future would be marked completed and release wait
  // before finishing all callbacks
  //
  // Because callbacks may still be running, snapshot the counter under the
  // same mutex the launcher uses instead of reading it unsynchronized.
  // NOTE(review): the launcher captures `counter` and `mtx` by reference; if a
  // callback can really fire after this test body returns, those references
  // would dangle -- confirm the lifetime guarantee of run_async.
  int launchedSoFar = 0;
  {
    std::lock_guard<std::mutex> guard(mtx);
    launchedSoFar = counter;
  }
  ASSERT_GE(launchedSoFar, 2);
}
// Cloning a parent whose two children share one ClassType must leave the
// cloned children on a common type while each keeps its own attribute value.
TEST(ModuleAPITest, Clone) {
  auto unit = std::make_shared<CompilationUnit>();
  const char* const kAttr = "attr";

  // Two instances of a single child ClassType, holding 2 and 3 respectively.
  auto childType = ClassType::create("child", unit, true);
  childType->addAttribute(kAttr, IntType::get());
  Module firstChild(unit, childType);
  firstChild.register_attribute(kAttr, IntType::get(), IValue(2), false);
  Module secondChild(unit, childType);
  secondChild.register_attribute(kAttr, IntType::get(), IValue(3), false);

  // Hang both children off one parent module.
  auto parentType = ClassType::create("parent", unit, true);
  Module parentModule(unit, parentType);
  parentModule.register_attribute(
      "c1", firstChild.type(), firstChild._ivalue(), false);
  parentModule.register_attribute(
      "c2", secondChild.type(), secondChild._ivalue(), false);

  Module clonedParent = parentModule.clone();
  const IValue clonedC1 = clonedParent.attr("c1");
  const IValue clonedC2 = clonedParent.attr("c2");

  // The cloned children still agree on their ClassType ...
  ASSERT_EQ(clonedC1.type(), clonedC2.type());

  // ... yet they are separate instances carrying their own values.
  ASSERT_EQ(Module(clonedC1.toObject()).attr(kAttr).toInt(), 2);
  ASSERT_EQ(Module(clonedC2.toObject()).attr(kAttr).toInt(), 3);
}
// Clones a parent module whose two submodules are registered under a shared
// module-interface type, and checks that the clone gets a type of its own.
TEST(ModuleAPITest, CloneWithModuleInterface) {
  auto unit = std::make_shared<CompilationUnit>();
  const char* const kInterfaceName = "__torch__.OneInterface";

  // A parent and two submodules that will be exposed through one interface.
  Module parentMod("parentMod", unit);
  Module subMod1("subMod1", unit);
  Module subMod2("subMod2", unit);

  // Load the interface declaration into the compilation unit.
  std::vector<at::IValue> constantTable;
  import_libs(
      unit,
      kInterfaceName,
      std::make_shared<Source>(moduleInterfaceSrc),
      constantTable);

  // Give each submodule its own `attr` value, then the same methods.
  subMod1.register_attribute("attr", IntType::get(), IValue(2), false);
  subMod2.register_attribute("attr", IntType::get(), IValue(4), false);
  for (const std::string& methodSrc : subModuleMethodsSrc) {
    for (Module* sub : {&subMod1, &subMod2}) {
      sub->define(methodSrc, nativeResolver());
    }
  }

  // Register both submodules on the parent, typed as the interface.
  auto interfaceType = unit->get_interface(kInterfaceName);
  parentMod.register_attribute("subMod1", interfaceType, subMod1._ivalue());
  parentMod.register_attribute("subMod2", interfaceType, subMod2._ivalue());
  parentMod.define(parentForward, nativeResolver());

  Module clonedMod = parentMod.clone();

  // clone will copy both type and data, therefore we'll have a
  // different type
  ASSERT_NE(clonedMod.type(), parentMod.type());
}
// Contrasts Module::clone() with Module::copy(): both carry the attribute
// value over, clone() yields a new type while copy() shares the original's
// type, and writing to the copy leaves the original and the clone untouched.
TEST(ModuleAPITest, Copy) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr_name = "attr";
  cls->addAttribute(attr_name, IntType::get());
  Module m(cu, cls);
  auto v = IValue(2);
  m.register_attribute(attr_name, IntType::get(), v, false);

  Module m2 = m.clone();
  Module m3 = m.copy();

  // Make sure copy works
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 2);

  // clone will copy both type and data, therefore we'll have a
  // different type
  ASSERT_NE(m.type(), m2.type());
  // copy only copies data, type is shared
  ASSERT_EQ(m.type(), m3.type());

  // change value of copied instance
  m3.register_attribute(attr_name, IntType::get(), IValue(3), false);
  // Verify value of original instance doesn't change. The original `m` is
  // checked explicitly: it is the one that shares a type with `m3`, so it is
  // the instance a leaky copy would corrupt.
  ASSERT_EQ(m.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
}
// Compares Module::deepcopy() with Module::copy() on a module holding int,
// string, tensor and tensor-list attributes: type sharing, overlap with the
// source object, and whether tensor data is shared.
TEST(ModuleAPITest, DeepCopy) {
  auto unit = std::make_shared<CompilationUnit>();
  auto moduleType = ClassType::create("foo.bar", unit, true);
  const char* const kStrAttr = "str_attr";
  const char* const kIntAttr = "int_attr";
  const char* const kTensorAttr = "tensor_attr";
  const char* const kTensorListAttr = "tensor_list_attr";
  moduleType->addAttribute(kIntAttr, IntType::get());
  moduleType->addAttribute(kStrAttr, StringType::get());
  moduleType->addAttribute(kTensorAttr, TensorType::get());
  moduleType->addAttribute(kTensorListAttr, ListType::ofTensors());

  Module source(unit, moduleType);
  c10::List<at::Tensor> tensors({at::rand(5), at::rand(5)});
  source.setattr(kIntAttr, IValue(2));
  source.setattr(kStrAttr, IValue("str"));
  source.setattr(kTensorAttr, at::randn(5));
  source.setattr(kTensorListAttr, tensors);

  Module deep = source.deepcopy();
  Module shallow = source.copy();

  // Both kinds of copy start out with the source's int value.
  ASSERT_EQ(deep.attr(kIntAttr).toInt(), 2);
  ASSERT_EQ(shallow.attr(kIntAttr).toInt(), 2);

  // Only the shallow copy overlaps with the source object.
  ASSERT_FALSE(IValue(deep._ivalue()).overlaps(IValue(source._ivalue())));
  ASSERT_TRUE(IValue(shallow._ivalue()).overlaps(IValue(source._ivalue())));

  // Both deepcopy and copy will preserve the type
  ASSERT_EQ(source.type(), deep.type());
  ASSERT_EQ(source.type(), shallow.type());

  // Rebinding the int on each copy must not leak into the source.
  deep.setattr(kIntAttr, IValue(3));
  shallow.setattr(kIntAttr, IValue(4));
  ASSERT_EQ(source.attr(kIntAttr).toInt(), 2);
  ASSERT_EQ(deep.attr(kIntAttr).toInt(), 3);
  ASSERT_EQ(shallow.attr(kIntAttr).toInt(), 4);

  // Tensor attribute as seen through the source, the deep copy (its own
  // Tensor) and the shallow copy (same Tensor as the source).
  at::Tensor sourceTensor = source.attr(kTensorAttr).toTensor();
  at::Tensor deepTensor = deep.attr(kTensorAttr).toTensor();
  at::Tensor shallowTensor = shallow.attr(kTensorAttr).toTensor();
  ASSERT_TRUE(sourceTensor.equal(deepTensor));
  ASSERT_TRUE(sourceTensor.equal(shallowTensor));

  // Zero the source tensor in place.
  sourceTensor.zero_();
  // The deep copy keeps its old contents ...
  ASSERT_FALSE(sourceTensor.equal(deepTensor));
  // ... while the shallow copy follows the source.
  ASSERT_TRUE(sourceTensor.equal(shallowTensor));
}
// A deep-copied module's string attribute keeps its value even after the
// std::string that was used to set it on the source module is modified.
TEST(ModuleAPITest, DeepCopyString) {
  auto unit = std::make_shared<CompilationUnit>();
  auto moduleType = ClassType::create("foo.bar", unit, true);
  const char* const kAttr = "attr1";
  moduleType->addAttribute(kAttr, StringType::get());

  std::string text = "str";
  Module source(unit, moduleType);
  source.setattr(kAttr, text);

  auto copied = source.deepcopy();
  const std::string expected = text;
  ASSERT_EQ(copied.attr(kAttr).toStringRef(), expected);

  // Appending to the local string must not show up in the copied module.
  text += "str";
  ASSERT_EQ(copied.attr(kAttr).toStringRef(), expected);
}
// Deep-copies a module holding an enum attribute and checks that the copy
// carries the same enum value/name/type, does not overlap the source, shares
// the source's module type, and is unaffected by later changes to the source.
TEST(ModuleAPITest, DeepCopyEnum) {
  auto unit = std::make_shared<CompilationUnit>();
  auto moduleType = ClassType::create("foo.bar", unit, true);
  const char* const kEnumAttr = "enum_attr";
  auto int_enum_type = EnumType::create(
      "enum_class",
      IntType::get(),
      {{"enum_name_1", 1}, {"enum_name_2", 2}},
      unit);
  moduleType->addAttribute(kEnumAttr, int_enum_type);

  Module source(unit, moduleType);
  source.setattr(
      kEnumAttr,
      IValue(c10::make_intrusive<ivalue::EnumHolder>(
          int_enum_type, "enum_name_1", 1)));

  Module deep = source.deepcopy();

  // The copy holds the first enumerator.
  auto deepHolder = deep.attr(kEnumAttr).toEnumHolder();
  ASSERT_EQ(deepHolder->value().toInt(), 1);
  ASSERT_EQ(deepHolder->name(), "enum_name_1");
  ASSERT_EQ(deepHolder->type(), int_enum_type);

  // Source and copy are separate objects ...
  ASSERT_FALSE(IValue(deep._ivalue()).overlaps(IValue(source._ivalue())));

  // ... of the same module type.
  ASSERT_EQ(source.type(), deep.type());

  // Switch the source to the second enumerator; the copy must not follow.
  source.setattr(
      kEnumAttr,
      IValue(c10::make_intrusive<ivalue::EnumHolder>(
          int_enum_type, "enum_name_2", 2)));
  const auto sourceValue =
      source.attr(kEnumAttr).toEnumHolder()->value().toInt();
  const auto deepValue = deep.attr(kEnumAttr).toEnumHolder()->value().toInt();
  ASSERT_NE(sourceValue, deepValue);
}
// Checks that Module::deepcopy() keeps aliasing relationships intact:
//  - a tensor shared by two list attributes is still shared in the copy;
//  - a tensor and a view of it are still aliases of each other in the copy.
TEST(ModuleAPITest, DeepCopyPreservesAliasing) {
  // check deepcopy preserves aliasing
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr1 = "attr1";
  auto attr2 = "attr2";
  auto attr3 = "attr3";
  auto attr4 = "attr4";
  cls->addAttribute(attr1, ListType::ofTensors());
  cls->addAttribute(attr2, ListType::ofTensors());
  cls->addAttribute(attr3, TensorType::get());
  cls->addAttribute(attr4, TensorType::get());
  Module m(cu, cls);
  auto t1 = at::rand(5);
  auto t2 = at::rand(5);
  auto t3 = at::rand(5);
  auto t4 = at::rand({5, 2});
  c10::List<at::Tensor> list1({t1, t2});
  c10::List<at::Tensor> list2({t1, t3});
  // first element of attr1 and attr2 are aliased
  m.setattr(attr1, list1);
  m.setattr(attr2, list2);
  // attr4 is a view of the tensor stored in attr3
  m.setattr(attr3, t4);
  m.setattr(attr4, t4.view(-1));

  auto copied = m.deepcopy();
  // test tensor aliasing
  auto copied_attr1_t1 = copied.attr(attr1).toList().get(0);
  auto copied_attr2_t1 = copied.attr(attr2).toList().get(0);
  ASSERT_TRUE(copied_attr1_t1.isAliasOf(copied_attr2_t1));

  // test aliasing from view
  // Compare attr3 against attr4 (the view). Reading attr3 twice would make
  // the assertion trivially true and test nothing.
  auto copied_attr3 = copied.attr(attr3);
  auto copied_attr4 = copied.attr(attr4);
  ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4));
}
// A constant added to the ClassType is visible through Module::hasattr() and
// Module::attr() alongside a regular attribute.
TEST(ModuleAPITest, Constants) {
  auto unit = std::make_shared<CompilationUnit>();
  auto moduleType = ClassType::create("foo.bar", unit, true);
  const char* const kAttr = "attr";
  const char* const kConst = "const";
  moduleType->addAttribute(kAttr, IntType::get());
  moduleType->addConstant(kConst, IValue(3));

  Module module(unit, moduleType);
  module.register_attribute(kAttr, IntType::get(), IValue(2), false);

  // Both names are reported as present ...
  ASSERT_TRUE(module.hasattr(kAttr));
  ASSERT_TRUE(module.hasattr(kConst));
  // ... and resolve to the attribute value and the constant respectively.
  ASSERT_EQ(module.attr(kAttr).toInt(), 2);
  ASSERT_EQ(module.attr(kConst).toInt(), 3);
}
// Registers one tensor parameter and two None-typed parameters, then checks
// that Module::parameters() reports a single entry while all three names are
// still visible through hasattr().
TEST(ModuleAPITest, Parameters) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  Module m(cu, cls);
  // Tensor parameter
  m.register_parameter(
      "tensor_param", at::empty({3}, at::kFloat), /* is_buffer */ false);
  // None parameter
  m.register_attribute(
      "none_param", NoneType::get(), IValue(), /* is_param */ true);
  m.register_attribute(
      "none_param2", NoneType::get(), IValue(), /* is_param */ true);
  auto param_list = m.parameters();
  // Unsigned literal: size() is unsigned, so comparing against a plain int
  // would be a signed/unsigned comparison inside the assertion.
  ASSERT_EQ(param_list.size(), 1u);
  ASSERT_TRUE(m.hasattr("tensor_param"));
  ASSERT_TRUE(m.hasattr("none_param"));
  ASSERT_TRUE(m.hasattr("none_param2"));
}
// Defines a TorchScript method with a defaulted int argument on a module and
// runs it: foo (1) + x (1) + default b (4) must give 6.
TEST(ModuleAPITest, Define) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // The method body must be indented under its `def` line to parse.
  m.define(R"(
    def add_it(self, x, b : int = 4):
      return self.foo + x + b
  )");
  auto result = m.run_method("add_it", torch::ones({}));
  // gtest assertion, matching the rest of this suite, rather than AT_ASSERT
  // with a raw float ==.
  ASSERT_FLOAT_EQ(result.toTensor().item<float>(), 6.0f);
}
// Checks via FileCheck that `forward`'s graph contains GetAttr before
// freezing, no longer does after torch::jit::freeze, and still does after
// optimize_for_inference on a module that has no `training` attribute.
TEST(ModuleAPITest, Freezing) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // The method body must be indented under its `def` line to parse.
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();

  auto forward_g = m.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);

  // Removal of GetAttr is done by freezing
  auto frozen_mod = torch::jit::freeze(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);

  // If no training mode is set, the module is NOT frozen by OFI
  auto frozen_mod2 = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod2.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);
}
// With a `training` attribute registered on the module,
// optimize_for_inference is expected to freeze it: GetAttr is present in
// `forward`'s graph beforehand and absent afterwards.
TEST(ModuleAPITest, OfiFreezesTraining) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // The method body must be indented under its `def` line to parse.
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.register_attribute("training", BoolType::get(), true);
  m.eval();

  // Before freezing, we have a GetAttr check
  auto forward_g = m.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);

  // Demonstrate that freezing happens when OFI is called
  // Removal of GetAttr is done by freezing, but only when training
  // attribute is set
  auto frozen_mod = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}
// optimize_for_inference on a module that has no `forward`, with `bar` passed
// as the extra method name: the returned module must compute the same result
// for `bar` as the original.
TEST(ModuleAPITest, OfiFreezesNoForward) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // The method body must be indented under its `def` line to parse.
  m.define(R"(
    def bar(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();

  // OFI is called without the presence of forward methods
  auto frozen_mod =
      torch::jit::optimize_for_inference(m, std::vector<std::string>{"bar"});
  ASSERT_EQ(
      m.run_method("bar", torch::ones({})).toTensor().item<float>(),
      frozen_mod.run_method("bar", torch::ones({})).toTensor().item<float>());
}
// Moves a module's parameter and buffer between devices with Module::to() and
// checks the device of each afterwards.
// NOTE(review): presumably the `_CUDA` suffix lets the test runner filter this
// test out on machines without CUDA -- confirm.
TEST(ModuleAPITest, To_CUDA) {
  Module m("test");
  {
    // test cuda to cpu for params and buffers
    m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
    m.register_buffer("bar", torch::ones({}, at::kCUDA));

    m.to(at::kCUDA);
    m.to(at::kCPU);
    // gtest assertions, matching the rest of this suite, rather than AT_ASSERT.
    ASSERT_TRUE(m.attr("foo").toTensor().device().is_cpu());
    ASSERT_TRUE(m.attr("bar").toTensor().device().is_cpu());
  }
  {
    // test cpu to cuda for params and buffers
    // Re-registers the same names with CPU tensors before moving to CUDA.
    m.register_parameter("foo", torch::ones({}), false);
    m.register_buffer("bar", torch::ones({}));

    m.to(at::kCUDA);
    ASSERT_TRUE(m.attr("foo").toTensor().device().is_cuda());
    ASSERT_TRUE(m.attr("bar").toTensor().device().is_cuda());
  }
}
} // namespace jit
} // namespace torch