send.py
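# Rank-1 "sender" half of what appears to be a two-process NCCL demo/repro:
# it captures dist.isend inside a CUDA graph and then replays the graph.
# A matching receiver process (rank 0, presumably posting dist.recv/irecv
# against the same TCP rendezvous) is assumed to be running but is not shown.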
import os
import torch
import torch.distributed as dist
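
# Disable NCCL's asynchronous error handling (presumably so a failed or hung
# send does not tear down the process group mid-experiment), pin this process
# to GPU 0, and join a 2-rank TCP-rendezvous process group as rank 1.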
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
torch.cuda.set_device(0)
dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:8000', rank=1, world_size=2)
a = torch.ones((5,5)).cuda()
b = torch.ones((5,5)).cuda()
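
# f() issues a non-blocking send of b to rank 0 and blocks on the returned
# work handle; the "None!!!" branch flags the case where isend returns no
# handle (e.g. if this rank were not part of the group).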
def f():
    # c = a*b
    h = dist.isend(b, 0)
    if h:
        h.wait()
    else:
        print("None!!!")
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(11):
        f()
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
# Capture a single isend into the graph: the kernel launches (and the device
# address of b) are recorded rather than executed eagerly
with torch.cuda.graph(g):
    f()
torch.cuda.synchronize()
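
# The graph captured the address of b, not its values, so overwriting b and
# replaying re-runs the recorded isend with the new contents: each of the 6
# replays should send a 5x5 tensor full of 114514 to rank 0.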
b[:, :] = 114514
for _ in range(6):
    g.replay()