ac0fa4ecb6dcae129c50decaf586ad5b3ee77dd9
[charm.git] / tests / converse / commbench / broadcast.c
/*****************************************************************************
 *  Benchmark to measure the performance of CmiSyncBroadcastAll.
 *
 *  Performs two kinds of measurement:
 *
 *  1. A flurry of consecutive broadcasts followed by a reduction.
 *
 *  2. A single broadcast after which every processor reports its receive
 *     time to processor 0 (clocks are first synchronized across processors
 *     so that the reported times are comparable).
 *
 *  Author: Nikhil Jain
 *  Date: Dec/26/2011
 *
 *****************************************************************************/
15
16 #include "converse.h"
17 #include "commbench.h"
18
19 typedef double* pdouble;
20
21 CpvStaticDeclare(int, numiter);
22 CpvStaticDeclare(int, nextidx);
23 CpvStaticDeclare(int, bcast_handler);
24 CpvStaticDeclare(int, bcast_reply);
25 CpvStaticDeclare(int, bcast_central);
26 CpvStaticDeclare(int, reduction_handler);
27 CpvStaticDeclare(int, sync_starter);
28 CpvStaticDeclare(int, sync_reply);
29 CpvStaticDeclare(double, starttime);
30 CpvStaticDeclare(double, lasttime);
31 CpvStaticDeclare(pdouble, timediff);
32 CpvStaticDeclare(int, currentPe);
33
/*
 * Payload sizes (bytes, excluding the Converse message header) exercised by
 * both benchmarks, with the number of broadcasts performed per size.  The
 * `time` column holds the measured wall-clock seconds for that size; it is
 * divided by `numiter` when results are printed.  An entry whose size is -1
 * terminates the table.
 */
static struct testdata {
  int size;     /* payload bytes per broadcast */
  int numiter;  /* broadcasts performed at this size */
  double time;  /* seconds measured at run time */
} sizes[] = {
  {       4, 1024, 0.0 },
  {      16, 1024, 0.0 },
  {      64, 1024, 0.0 },
  {     256, 1024, 0.0 },
  {    1024, 1024, 0.0 },
  {    4096, 1024, 0.0 },
  {   16384, 1024, 0.0 },
  {   65536, 1024, 0.0 },
  {  262144, 1024, 0.0 },
  { 1048576, 1024, 0.0 },
  {      -1,   -1, 0.0 }   /* sentinel */
};
51
52 typedef struct _timemsg {
53       char head[CmiMsgHeaderSizeBytes];
54       double time;
55       int srcpe;
56 } *ptimemsg;
57
58 typedef struct _timemsg timemsg;
59
/*
 * printf-style template for one result line: benchmark label, average
 * seconds per broadcast, payload size in bytes.  Points at a string
 * literal, so both the pointer and the characters are const.
 */
static const char *const sync_outstr =
"[broadcast] (%s) %le seconds per %d bytes\n"
;
63
/*
 * Merge callback passed to CmiReduce.  The reduction only signals
 * completion (bcast_handler contributes header-only messages), so the
 * local contribution is returned unchanged; size, remote and count are
 * deliberately ignored.
 */
static void * reduceMessage(int *size, void *data, void **remote, int count)
{
  (void)size;
  (void)remote;
  (void)count;
  return data;
}
68
69 static void print_results(char *func)
70 {
71   int i=0;
72
73   while(sizes[i].size != (-1)) {
74     CmiPrintf(sync_outstr, func, sizes[i].time/sizes[i].numiter, sizes[i].size);
75     i++;
76   }
77 }
78
79 static void bcast_handler(void *msg)
80 {
81   int idx = CpvAccess(nextidx);
82   void *red_msg;
83
84   CpvAccess(numiter)++;
85   if(CpvAccess(numiter)<sizes[idx].numiter) {
86     if(CmiMyPe() == 0) {
87       CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[idx].size, msg);
88       CmiFree(msg);
89     }else
90         CmiFree(msg);
91
92   } else {
93       CmiFree(msg);
94     red_msg = CmiAlloc(CmiMsgHeaderSizeBytes);
95     CmiSetHandler(red_msg, CpvAccess(reduction_handler));
96     CmiReduce(red_msg, CmiMsgHeaderSizeBytes, reduceMessage);
97     if(CmiMyPe() != 0) {
98       CpvAccess(nextidx) = idx + 1;
99       CpvAccess(numiter) = 0;
100     }
101   }
102 }
103
104 static void reduction_handler(void *msg) 
105 {
106   int i=0;
107   int idx = CpvAccess(nextidx);
108   EmptyMsg emsg;
109
110   sizes[idx].time = CmiWallTimer() - CpvAccess(starttime);
111   CmiFree(msg);
112   CpvAccess(numiter) = 0;
113   idx++;
114   if(sizes[idx].size == (-1)) {
115     print_results("Consecutive CmiSyncBroadcastAll");
116     CpvAccess(nextidx) = 0;
117     CpvAccess(numiter) = 0;
118     while(sizes[i].size != (-1)) {
119       sizes[i].time = 0;
120       i++;
121     }
122     CmiSetHandler(&emsg, CpvAccess(sync_reply));
123     CpvAccess(lasttime) = CmiWallTimer(); 
124     CmiSyncSend(CpvAccess(currentPe), sizeof(EmptyMsg), &emsg);
125     return;
126   } else {
127     CpvAccess(nextidx) = idx;
128     msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[idx].size);
129     CmiSetHandler(msg, CpvAccess(bcast_handler));
130     CpvAccess(starttime) = CmiWallTimer();
131     CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[idx].size, msg);
132     CmiFree(msg);
133   }
134 }
135    
136 static void sync_starter(void *msg) 
137 {
138   EmptyMsg emsg;    
139   ptimemsg tmsg = (ptimemsg)msg;
140
141   double midTime = (CmiWallTimer() + CpvAccess(lasttime))/2;
142   CpvAccess(timediff)[CpvAccess(currentPe)] = midTime - tmsg->time;
143   CmiFree(msg);
144
145   CpvAccess(currentPe)++;
146   if(CpvAccess(currentPe) < CmiNumPes()) {
147     CmiSetHandler(&emsg, CpvAccess(sync_reply));
148     CpvAccess(lasttime) = CmiWallTimer(); 
149     CmiSyncSend(CpvAccess(currentPe), sizeof(EmptyMsg), &emsg);
150   } else {
151     msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[0].size);
152     CmiSetHandler(msg, CpvAccess(bcast_reply));
153     CpvAccess(currentPe) = 0;
154     CpvAccess(starttime) = CmiWallTimer();
155     CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[0].size, msg);
156     CmiFree(msg);
157   }
158 }
159
160 static void sync_reply(void *msg) 
161 {
162   ptimemsg tmsg = (ptimemsg)CmiAlloc(sizeof(timemsg));
163   tmsg->time = CmiWallTimer();
164
165   CmiFree(msg);
166   CmiSetHandler(tmsg, CpvAccess(sync_starter));
167   CmiSyncSend(0, sizeof(timemsg), tmsg);
168   CmiFree(tmsg);
169 }
170  
171 static void bcast_reply(void *msg)
172 {
173   ptimemsg tmsg = (ptimemsg)CmiAlloc(sizeof(timemsg));
174   tmsg->time = CmiWallTimer();
175   tmsg->srcpe = CmiMyPe();
176   CmiFree(msg);
177   CmiSetHandler(tmsg, CpvAccess(bcast_central));
178   CmiSyncSend(0, sizeof(timemsg), tmsg);
179   CmiFree(tmsg);
180 }
181
182 static void bcast_central(void *msg)
183 {
184   EmptyMsg emsg;
185   ptimemsg tmsg = (ptimemsg)msg;
186   if(CpvAccess(currentPe) == 0) {
187     CpvAccess(lasttime) = tmsg->time - CpvAccess(starttime) + 
188                           CpvAccess(timediff)[tmsg->srcpe];
189   } else if((tmsg->time - CpvAccess(starttime) + 
190     CpvAccess(timediff)[tmsg->srcpe]) > CpvAccess(lasttime)) {
191     CpvAccess(lasttime) = tmsg->time - CpvAccess(starttime) +
192                           CpvAccess(timediff)[tmsg->srcpe];
193   }
194   CmiFree(msg);
195   CpvAccess(currentPe)++;
196   if(CpvAccess(currentPe) == CmiNumPes()) {
197     sizes[CpvAccess(nextidx)].time += CpvAccess(lasttime);
198     CpvAccess(numiter)++;
199     if(CpvAccess(numiter)<sizes[CpvAccess(nextidx)].numiter) {
200       msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size);
201       CpvAccess(currentPe) = 0;
202       CmiSetHandler(msg, CpvAccess(bcast_reply));
203       CpvAccess(starttime) = CmiWallTimer();
204       CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size, msg);
205       CmiFree(msg);
206     } else {
207       CpvAccess(numiter) = 0;
208       CpvAccess(nextidx)++;
209       if(sizes[CpvAccess(nextidx)].size == (-1)) {
210         print_results("CmiSyncBroadcastAll");
211         CmiSetHandler(&emsg, CpvAccess(ack_handler));
212         CmiSyncSend(0, sizeof(EmptyMsg), &emsg);
213         return;
214       } else {
215         msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size);
216         CpvAccess(currentPe) = 0;
217         CmiSetHandler(msg, CpvAccess(bcast_reply));
218         CpvAccess(starttime) = CmiWallTimer();
219         CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[CpvAccess(nextidx)].size, 
220                             msg);
221         CmiFree(msg);
222       }
223     }
224   }
225 }
226
227 void broadcast_init(void)
228 {
229   void *msg;
230
231   msg = CmiAlloc(CmiMsgHeaderSizeBytes+sizes[0].size);
232   CmiSetHandler(msg, CpvAccess(bcast_handler));
233   CpvAccess(starttime) = CmiWallTimer();
234   CmiSyncBroadcastAll(CmiMsgHeaderSizeBytes+sizes[0].size, msg);
235   CmiFree(msg);
236 }
237
238 void broadcast_moduleinit(void)
239 {
240   CpvInitialize(int, numiter);
241   CpvInitialize(int, nextidx);
242   CpvInitialize(double, starttime);
243   CpvInitialize(double, lasttime);
244   CpvInitialize(pdouble, timediff); 
245   CpvInitialize(int, currentPe);
246   CpvInitialize(int, bcast_handler);
247   CpvInitialize(int, bcast_reply);
248   CpvInitialize(int, bcast_central);
249   CpvInitialize(int, reduction_handler);
250   CpvInitialize(int, sync_starter);
251   CpvInitialize(int, sync_reply);
252   CpvAccess(numiter) = 0;
253   CpvAccess(nextidx) = 0;
254   CpvAccess(currentPe) = 0;
255   CpvAccess(timediff) = (pdouble)malloc(CmiNumPes()*sizeof(double));
256   CpvAccess(bcast_handler) = CmiRegisterHandler((CmiHandler)bcast_handler);
257   CpvAccess(bcast_reply) = CmiRegisterHandler((CmiHandler)bcast_reply);
258   CpvAccess(bcast_central) = CmiRegisterHandler((CmiHandler)bcast_central);
259   CpvAccess(reduction_handler) = CmiRegisterHandler((CmiHandler)reduction_handler);
260   CpvAccess(sync_starter) = CmiRegisterHandler((CmiHandler)sync_starter);
261   CpvAccess(sync_reply) = CmiRegisterHandler((CmiHandler)sync_reply);
262 }