forked from dmlc/mshadow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
local_sum-inl.h
119 lines (115 loc) · 3.73 KB
/
local_sum-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// This is an example demonstrating the usage of mshadow ps
#include <cstdio>
// use openmp to launch multiple threads
#include <omp.h>
#include <mshadow/tensor.h>
#include <mshadow-ps/mshadow_ps.h>
// simple util to print result
// Writes a 2-D CPU tensor to stdout: one output line per row (dimension 0),
// each element of the row formatted with "%g" and followed by a space.
// Marked inline because this is a non-template function defined in a header
// (local_sum-inl.h); without it, including the header from more than one
// translation unit of the same program yields multiple-definition link errors.
inline void Print_(mshadow::Tensor<mshadow::cpu, 2, float> ts) {
  for (mshadow::index_t i = 0; i < ts.size(0); ++i) {
    for (mshadow::index_t j = 0; j < ts.size(1); ++j) {
      printf("%g ", ts[i][j]);
    }
    // terminate the current row
    printf("\n");
  }
}
// Device-agnostic print helper: stages the tensor in a CPU-side container
// and hands that copy to Print_ for the actual output.
// ts: 2-D float tensor living on device type xpu.
template<typename xpu>
inline void Print(mshadow::Tensor<xpu, 2, float> ts) {
  typedef mshadow::TensorContainer<mshadow::cpu, 2, float> HostBuffer;
  // CPU staging buffer, resized to the shape of the input
  HostBuffer host_copy;
  host_copy.Resize(ts.shape_);
  // bring the content over to the CPU side, then print it
  mshadow::Copy(host_copy, ts);
  Print_(host_copy);
}
// this function is runed by specific thread
template<typename xpu>
inline void RunWorkerThread(int devid,
mshadow::ps::ISharedModel<xpu, float> *ps) {
// initialize tensor engine
mshadow::InitTensorEngine<xpu>(devid);
mshadow::Stream<xpu> *stream = mshadow::NewStream<xpu>();
// allocate tensor on xpu
mshadow::TensorContainer<xpu, 2> data(mshadow::Shape2(2, 3));
// set the computation stream to the new allocated stream
// this will make subsequent computation whose target is data
// to use the stream, stream is needed for async execution in GPU
data.set_stream(stream);
// assume these operations sets the content of dataient
data[0] = 1.0f;
data[1] = devid + data[0];
printf("dev%d: before sync, data:\n", devid);
// use print to show result, do not call
// print normally since Copy will block
Print(data);
printf("====================\n");
// intiaialize the key, register the shape on parameter server
ps->InitKey(data[0].shape_, 0, devid);
ps->InitKey(data[1].shape_, 1, devid);
// push data[0] out, for update, or aggregation
// 0 is the key of the data, devid is the current device id
ps->Push(data[0], 0, devid);
// pull request is used to request the data to be copied back
// once computation is done
ps->PullReq(data[0], 0, devid);
// computation can be done here..
// the pull request handler will be overlapped with
// similar as previous call
ps->Push(data[1], 1, devid);
ps->PullReq(data[1], 1, devid);
// more computation can be done here...
// the computation will be overlapped
// PullWait will block until these request finishes
ps->PullWait(0, devid);
ps->PullWait(1, devid);
printf("dev%d: after sync, data:\n", devid);
// use print to show result, do not call
// print normally since Copy will block
Print(data);
printf("====================\n");
mshadow::DeleteStream(stream);
mshadow::ShutdownTensorEngine<xpu>();
}
namespace mshadow {
namespace ps {
// model updater is used when update is happening on server side
// if we only use parameter server for sum aggregation
// this is not needed, but we must declare this function to return NULL
//
// Marked inline: an explicit specialization is an ordinary function
// definition, and this one lives in a header (local_sum-inl.h). Without
// inline, including the header from more than one translation unit of the
// same program violates the one-definition rule (duplicate symbol at link).
template<>
inline IModelUpdater<float> *CreateModelUpdater(void) {
  return NULL;
}
}  // namespace ps
}  // namespace mshadow
template<typename xpu>
inline int Run(int argc, char *argv[]) {
if (argc < 2) {
printf("Usage: device list\n"\
"\tfor CPU the device list can be arbitrary\n"\
"\tfor GPU the device list need to be actual device index\n");
return 0;
}
#if MSHADOW_RABIT_PS
rabit::Init(argc, argv);
#endif
// list of device ids
std::vector<int> devs;
// initialization
for (int i = 1; i < argc; ++i) {
// record the device id
devs.push_back(atoi(argv[i]));
}
mshadow::ps::ISharedModel<xpu, float>
*ps = mshadow::ps::CreateSharedModel<xpu, float>("local");
// intiaialize the ps
ps->Init(devs);
// use openmp to launch #devs threads
#pragma omp parallel num_threads(devs.size())
{
int tid = omp_get_thread_num();
RunWorkerThread<xpu>(devs[tid], ps);
}
delete ps;
#if MSHADOW_RABIT_PS
rabit::Finalize();
#endif
return 0;
}