Adding performance benchmarks for CmiReduce and Broadcast in commbench
[charm.git] / tests / converse / commbench / broadcast.c
1 /*****************************************************************************
2  *  Benchmark to measure performance of CmiSyncBroadcast
3  *  
4  *  Does two types of benchmarking-
5  *
6  *  1. A flurry of Bcasts followed by a reduction
7  *
8  *  2. Singleton broadcast followed by reduction (clocks synchronized
9  *                                                across processors)
10  *
11  *  Author- Nikhil Jain
12  *  Date- Dec/26/2011
13  *
14  *****************************************************************************/
15
16 #include "converse.h"
17 #include "commbench.h"
18
19 typedef double* pdouble;
20
21 CpvStaticDeclare(int, numiter);
22 CpvStaticDeclare(int, nextidx);
23 CpvStaticDeclare(int, bcast_handler);
24 CpvStaticDeclare(int, bcast_reply);
25 CpvStaticDeclare(int, bcast_central);
26 CpvStaticDeclare(int, reduction_handler);
27 CpvStaticDeclare(int, sync_starter);
28 CpvStaticDeclare(int, sync_reply);
29 CpvStaticDeclare(double, starttime);
30 CpvStaticDeclare(double, lasttime);
31 CpvStaticDeclare(pdouble, timediff);
32 CpvStaticDeclare(int, currentPe);
33
/* One benchmark configuration plus the slot its measurement is stored in. */
struct testdata {
  int size;      /* payload bytes, added on top of CmiMsgHeaderSizeBytes */
  int numiter;   /* number of broadcasts performed for this size         */
  double time;   /* measured time, filled in by the handlers             */
};

/* Table of message sizes to benchmark; the row with size == -1 terminates it. */
static struct testdata sizes[] = {
  {      4, 1024, 0.0},
  {     16, 1024, 0.0},
  {     64, 1024, 0.0},
  {    256, 1024, 0.0},
  {   1024, 1024, 0.0},
  {   4096, 1024, 0.0},
  {  16384, 1024, 0.0},
  {  65536, 1024, 0.0},
  { 262144, 1024, 0.0},
  {1048576, 1024, 0.0},
  {     -1,   -1, 0.0},
};
51
52 typedef struct _timemsg {
53       char head[CmiMsgHeaderSizeBytes];
54       double time;
55       int srcpe;
56 } *ptimemsg;
57
58 typedef struct _timemsg timemsg;
59
/* Format for one result line: benchmark label, seconds per iteration, payload bytes. */
static char *sync_outstr = "[broadcast] (%s) %le seconds per %d bytes\n";
63
/*
 * Merge callback handed to CmiReduce() by bcast_handler().
 *
 * The reduction carries no payload, so nothing is combined: the local
 * contribution is returned untouched and the remaining arguments are ignored.
 */
static void *reduceMessage(int *size, void *data, void **remote, int count)
{
  (void)size;
  (void)remote;
  (void)count;

  return data;
}
68
69 static void print_results(char *func)
70 {
71   int i=0;
72
73   while(sizes[i].size != (-1)) {
74     CmiPrintf(sync_outstr, func, sizes[i].time/sizes[i].numiter, sizes[i].size);
75     i++;
76   }
77 }
78
79 static void bcast_handler(void *msg)
80 {
81   int idx = CpvAccess(nextidx);
82   void *red_msg;
83
84   CpvAccess(numiter)++;
85   if(CpvAccess(numiter)<sizes[idx].numiter) {
86     if(CmiMyPe() == 0) {
87       CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[idx].size, msg);
88       CmiFree(msg);
89     }
90   } else {
91     red_msg = CmiAlloc(CmiMsgHeaderSizeBytes);
92     CmiSetHandler(red_msg, CpvAccess(reduction_handler));
93     CmiReduce(red_msg, CmiMsgHeaderSizeBytes, reduceMessage);
94     if(CmiMyPe() != 0) {
95       CpvAccess(nextidx) = idx + 1;
96       CpvAccess(numiter) = 0;
97     }
98   }
99 }
100
101 static void reduction_handler(void *msg) 
102 {
103   int i=0;
104   int idx = CpvAccess(nextidx);
105   EmptyMsg emsg;
106
107   sizes[idx].time = CmiWallTimer() - CpvAccess(starttime);
108   CmiFree(msg);
109   CpvAccess(numiter) = 0;
110   idx++;
111   if(sizes[idx].size == (-1)) {
112     print_results("Consecutive CmiSyncBroadcastAll");
113     CpvAccess(nextidx) = 0;
114     CpvAccess(numiter) = 0;
115     while(sizes[i].size != (-1)) {
116       sizes[i].time = 0;
117       i++;
118     }
119     CmiSetHandler(&emsg, CpvAccess(sync_reply));
120     CpvAccess(lasttime) = CmiWallTimer(); 
121     CmiSyncSend(CpvAccess(currentPe), sizeof(EmptyMsg), &emsg);
122     return;
123   } else {
124     CpvAccess(nextidx) = idx;
125     msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[idx].size);
126     CmiSetHandler(msg, CpvAccess(bcast_handler));
127     CpvAccess(starttime) = CmiWallTimer();
128     CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[idx].size, msg);
129     CmiFree(msg);
130   }
131 }
132    
133 static void sync_starter(void *msg) 
134 {
135   EmptyMsg emsg;    
136   ptimemsg tmsg = (ptimemsg)msg;
137
138   double midTime = (CmiWallTimer() + CpvAccess(lasttime))/2;
139   CpvAccess(timediff)[CpvAccess(currentPe)] = midTime - tmsg->time;
140   CmiFree(msg);
141
142   CpvAccess(currentPe)++;
143   if(CpvAccess(currentPe) < CmiNumPes()) {
144     CmiSetHandler(&emsg, CpvAccess(sync_reply));
145     CpvAccess(lasttime) = CmiWallTimer(); 
146     CmiSyncSend(CpvAccess(currentPe), sizeof(EmptyMsg), &emsg);
147   } else {
148     msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[0].size);
149     CmiSetHandler(msg, CpvAccess(bcast_reply));
150     CpvAccess(currentPe) = 0;
151     CpvAccess(starttime) = CmiWallTimer();
152     CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[0].size, msg);
153     CmiFree(msg);
154   }
155 }
156
157 static void sync_reply(void *msg) 
158 {
159   ptimemsg tmsg = (ptimemsg)CmiAlloc(sizeof(timemsg));
160   tmsg->time = CmiWallTimer();
161
162   CmiFree(msg);
163   CmiSetHandler(tmsg, CpvAccess(sync_starter));
164   CmiSyncSend(0, sizeof(timemsg), tmsg);
165   CmiFree(tmsg);
166 }
167  
168 static void bcast_reply(void *msg)
169 {
170   ptimemsg tmsg = (ptimemsg)CmiAlloc(sizeof(timemsg));
171   tmsg->time = CmiWallTimer();
172   tmsg->srcpe = CmiMyPe();
173   CmiFree(msg);
174   CmiSetHandler(tmsg, CpvAccess(bcast_central));
175   CmiSyncSend(0, sizeof(timemsg), tmsg);
176   CmiFree(tmsg);
177 }
178
179 static void bcast_central(void *msg)
180 {
181   EmptyMsg emsg;
182   ptimemsg tmsg = (ptimemsg)msg;
183   if(CpvAccess(currentPe) == 0) {
184     CpvAccess(lasttime) = tmsg->time - CpvAccess(starttime) + 
185                           CpvAccess(timediff)[tmsg->srcpe];
186   } else if((tmsg->time - CpvAccess(starttime) + 
187     CpvAccess(timediff)[tmsg->srcpe]) > CpvAccess(lasttime)) {
188     CpvAccess(lasttime) = tmsg->time - CpvAccess(starttime) +
189                           CpvAccess(timediff)[tmsg->srcpe];
190   }
191   CmiFree(msg);
192   CpvAccess(currentPe)++;
193   if(CpvAccess(currentPe) == CmiNumPes()) {
194     sizes[CpvAccess(nextidx)].time += CpvAccess(lasttime);
195     CpvAccess(numiter)++;
196     if(CpvAccess(numiter)<sizes[CpvAccess(nextidx)].numiter) {
197       msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size);
198       CpvAccess(currentPe) = 0;
199       CmiSetHandler(msg, CpvAccess(bcast_reply));
200       CpvAccess(starttime) = CmiWallTimer();
201       CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size, msg);
202       CmiFree(msg);
203     } else {
204       CpvAccess(numiter) = 0;
205       CpvAccess(nextidx)++;
206       if(sizes[CpvAccess(nextidx)].size == (-1)) {
207         print_results("CmiSyncBroadcastAll");
208         CmiSetHandler(&emsg, CpvAccess(ack_handler));
209         CmiSyncSend(0, sizeof(EmptyMsg), &emsg);
210         return;
211       } else {
212         msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size);
213         CpvAccess(currentPe) = 0;
214         CmiSetHandler(msg, CpvAccess(bcast_reply));
215         CpvAccess(starttime) = CmiWallTimer();
216         CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size, 
217                             msg);
218         CmiFree(msg);
219       }
220     }
221   }
222 }
223
224 void broadcast_init(void)
225 {
226   void *msg;
227
228   msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[0].size);
229   CmiSetHandler(msg, CpvAccess(bcast_handler));
230   CpvAccess(starttime) = CmiWallTimer();
231   CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[0].size, msg);
232   CmiFree(msg);
233 }
234
235 void broadcast_moduleinit(void)
236 {
237   CpvInitialize(int, numiter);
238   CpvInitialize(int, nextidx);
239   CpvInitialize(double, starttime);
240   CpvInitialize(double, lasttime);
241   CpvInitialize(pdouble, timediff); 
242   CpvInitialize(int, currentPe);
243   CpvInitialize(int, bcast_handler);
244   CpvInitialize(int, bcast_reply);
245   CpvInitialize(int, bcast_central);
246   CpvInitialize(int, reduction_handler);
247   CpvInitialize(int, sync_starter);
248   CpvInitialize(int, sync_reply);
249   CpvAccess(numiter) = 0;
250   CpvAccess(nextidx) = 0;
251   CpvAccess(currentPe) = 0;
252   CpvAccess(timediff) = (pdouble)malloc(CmiNumPes()*sizeof(double));
253   CpvAccess(bcast_handler) = CmiRegisterHandler((CmiHandler)bcast_handler);
254   CpvAccess(bcast_reply) = CmiRegisterHandler((CmiHandler)bcast_reply);
255   CpvAccess(bcast_central) = CmiRegisterHandler((CmiHandler)bcast_central);
256   CpvAccess(reduction_handler) = CmiRegisterHandler((CmiHandler)reduction_handler);
257   CpvAccess(sync_starter) = CmiRegisterHandler((CmiHandler)sync_starter);
258   CpvAccess(sync_reply) = CmiRegisterHandler((CmiHandler)sync_reply);
259 }