fixed a typo in getting processor count.
[charm.git] / src / conv-ldb / cldb.workstealing.c
1 #include <stdlib.h>
2
3 #include "converse.h"
4 #include "cldb.workstealing.h"
5 #include "queueing.h"
6 #include "cldb.h"
7
8 #define IDLE_IMMEDIATE          0
9 #define TRACE_USEREVENTS        0
10
11 #define PERIOD 10                /* default: 30 */
12 #define MSGDELAY 10
13 #define MAXOVERLOAD 1
14
15 #define LOADTHRESH       3
16
17
18 typedef struct CldProcInfo_s {
19   int    balanceEvt;            /* user event for balancing */
20   int    idleEvt;               /* user event for idle balancing */
21   int    idleprocEvt;           /* user event for processing idle req */
22 } *CldProcInfo;
23
24 int _stealonly1 = 0;
25 int workstealingproactive = 0;
26
27 CpvStaticDeclare(CldProcInfo, CldData);
28 CpvStaticDeclare(int, CldAskLoadHandlerIndex);
29 CpvStaticDeclare(int, CldAckNoTaskHandlerIndex);
30 CpvStaticDeclare(int, isStealing);
31
32
33 char *CldGetStrategy(void)
34 {
35   return "work stealing";
36 }
37
38
39 static void StealLoad()
40 {
41   int i;
42   double startT;
43   requestmsg msg;
44   int myload;
45   int  victim;
46   int mype;
47   int numpes;
48
49   /* CcdRaiseCondition(CcdUSER); */
50
51   if (CpvAccess(isStealing)) return;    /* already stealing, return */
52   CpvAccess(isStealing) = 1;
53
54   myload = CldLoad();
55
56   mype = CmiMyPe();
57   msg.from_pe = mype;
58   numpes = CmiNumPes();
59   do{
60       victim = (((CrnRand()+mype)&0x7FFFFFFF)%numpes);
61   }while(victim == mype);
62 if (mype == 2) CmiPrintf("steal from %d\n", victim);
63
64   CmiSetHandler(&msg, CpvAccess(CldAskLoadHandlerIndex));
65 #if IDLE_IMMEDIATE
66   /* fixme */
67   CmiBecomeImmediate(&msg);
68 #endif
69   msg.to_rank = CmiRankOf(victim);
70   CmiSyncSend(victim, sizeof(requestmsg),(char *)&msg);
71   
72 #if CMK_TRACE_ENABLED && TRACE_USEREVENTS
73   traceUserBracketEvent(cldData->idleEvt, now, CmiWallTimer());
74 #endif
75 }
76
77 void LoadNotifyFn(int l)
78 {
79     if(workstealingproactive)
80     {
81         if(CldLoad() < 3)
82             StealLoad();
83     }
84 }
85 /* since I am idle, ask for work from neighbors */
86
87 static void CldBeginIdle(void *dummy)
88 {
89     StealLoad();
90
91 }
92 /* immediate message handler, work at node level */
93 /* send some work to requested proc */
94 static void CldAskLoadHandler(requestmsg *msg)
95 {
96   int receiver, rank, recvIdx, i;
97   int myload = CldLoad();
98
99   int sendLoad;
100   sendLoad = myload / 2; 
101   receiver = msg->from_pe;
102   /* only give you work if I have more than 1 */
103   if (myload>LOADTHRESH) {
104       if(_stealonly1) sendLoad = 1;
105       rank = CmiMyRank();
106       if (msg->to_rank != -1) rank = msg->to_rank;
107       CldMultipleSend(receiver, sendLoad, rank, 0);
108   }else
109   {
110       msg->from_pe = CmiMyPe();
111       msg->to_rank = CmiMyRank();
112
113       /* CcdRaiseCondition(CcdUSER); */
114
115       CmiSetHandler(msg, CpvAccess(CldAckNoTaskHandlerIndex));
116       CmiSyncSendAndFree(receiver, sizeof(requestmsg),(char *)msg);
117     /* send ack indicating there is no task */
118   }
119 }
120
121 void  CldAckNoTaskHandler(requestmsg *msg)
122 {
123   int victim; 
124   int notaskpe = msg->from_pe;
125   int mype = CmiMyPe();
126
127   /* CcdRaiseCondition(CcdUSER); */
128
129   if (CmiNumPes()==2) victim = 2-mype;
130   else
131   do{
132       /*victim = (((CrnRand()+notaskpe)&0x7FFFFFFF)%CmiNumPes());*/
133       victim = (((CrnRand())&0x7FFFFFFF)%CmiNumPes());
134   }while(victim == mype || victim == notaskpe);
135
136   /* reuse msg */
137   msg->to_rank = CmiRankOf(victim);
138   msg->from_pe = mype;
139   CmiSetHandler(msg, CpvAccess(CldAskLoadHandlerIndex));
140   CmiSyncSendAndFree(victim, sizeof(requestmsg),(char *)msg);
141
142   CpvAccess(isStealing) = 1;
143 }
144
145 void CldHandler(void *msg)
146 {
147   CldInfoFn ifn; CldPackFn pfn;
148   int len, queueing, priobits; unsigned int *prioptr;
149   
150   CldRestoreHandler(msg);
151   ifn = (CldInfoFn)CmiHandlerToFunction(CmiGetInfo(msg));
152   ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
153   CsdEnqueueGeneral(msg, queueing, priobits, prioptr);
154 }
155
156 void CldBalanceHandler(void *msg)
157 {
158   CldRestoreHandler(msg);
159   CldPutToken(msg);
160   CpvAccess(isStealing) = 0;
161 }
162
163 void CldEnqueueGroup(CmiGroup grp, void *msg, int infofn)
164 {
165   int len, queueing, priobits,i; unsigned int *prioptr;
166   CldInfoFn ifn = (CldInfoFn)CmiHandlerToFunction(infofn);
167   CldPackFn pfn;
168   ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
169   if (pfn) {
170     pfn(&msg);
171     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
172   }
173   CldSwitchHandler(msg, CpvAccess(CldHandlerIndex));
174   CmiSetInfo(msg,infofn);
175
176   CmiSyncMulticastAndFree(grp, len, msg);
177 }
178
179 void CldEnqueueMulti(int npes, int *pes, void *msg, int infofn)
180 {
181   int len, queueing, priobits,i; unsigned int *prioptr;
182   CldInfoFn ifn = (CldInfoFn)CmiHandlerToFunction(infofn);
183   CldPackFn pfn;
184   ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
185   if (pfn) {
186     pfn(&msg);
187     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
188   }
189   CldSwitchHandler(msg, CpvAccess(CldHandlerIndex));
190   CmiSetInfo(msg,infofn);
191   CmiSyncListSendAndFree(npes, pes, len, msg);
192 }
193
194 void CldEnqueue(int pe, void *msg, int infofn)
195 {
196   int len, queueing, priobits, avg; unsigned int *prioptr;
197   CldInfoFn ifn = (CldInfoFn)CmiHandlerToFunction(infofn);
198   CldPackFn pfn;
199
200   if ((pe == CLD_ANYWHERE) && (CmiNumPes() > 1)) {
201       pe = CmiMyPe();
202     /* always pack the message because the message may be move away
203        to a different processor later by CldGetToken() */
204     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
205     if (pfn && CmiNumNodes()>1) {
206        pfn(&msg);
207        ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
208     }
209     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
210     CmiSetInfo(msg,infofn);
211     CldPutToken(msg);
212   } 
213   else if ((pe == CmiMyPe()) || (CmiNumPes() == 1)) {
214     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
215     CsdEnqueueGeneral(msg, queueing, priobits, prioptr);
216   }
217   else {
218     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
219     if (pfn && CmiNodeOf(pe) != CmiMyNode()) {
220       pfn(&msg);
221       ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
222     }
223     CldSwitchHandler(msg, CpvAccess(CldHandlerIndex));
224     CmiSetInfo(msg,infofn);
225     if (pe==CLD_BROADCAST) 
226       CmiSyncBroadcastAndFree(len, msg);
227     else if (pe==CLD_BROADCAST_ALL)
228       CmiSyncBroadcastAllAndFree(len, msg);
229     else CmiSyncSendAndFree(pe, len, msg);
230   }
231 }
232
233 void CldNodeEnqueue(int node, void *msg, int infofn)
234 {
235   int len, queueing, priobits, pe, avg; unsigned int *prioptr;
236   CldInfoFn ifn = (CldInfoFn)CmiHandlerToFunction(infofn);
237   CldPackFn pfn;
238   if ((node == CLD_ANYWHERE) && (CmiNumPes() > 1)) {
239       pe = CmiMyPe();
240       node = CmiNodeOf(pe);
241       ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
242       CsdNodeEnqueueGeneral(msg, queueing, priobits, prioptr);
243   }
244   else if ((node == CmiMyNode()) || (CmiNumPes() == 1)) {
245     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
246     CsdNodeEnqueueGeneral(msg, queueing, priobits, prioptr);
247   } 
248   else {
249     ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
250     if (pfn) {
251         pfn(&msg);
252         ifn(msg, &pfn, &len, &queueing, &priobits, &prioptr);
253     }
254     CldSwitchHandler(msg, CpvAccess(CldHandlerIndex));
255     CmiSetInfo(msg,infofn);
256     if (node==CLD_BROADCAST) { CmiSyncNodeBroadcastAndFree(len, msg); }
257     else if (node==CLD_BROADCAST_ALL){CmiSyncNodeBroadcastAllAndFree(len,msg);}
258     else CmiSyncNodeSendAndFree(node, len, msg);
259   }
260 }
261
262
263 void CldGraphModuleInit(char **argv)
264 {
265   CpvInitialize(CldProcInfo, CldData);
266   CpvInitialize(int, CldAskLoadHandlerIndex);
267   CpvInitialize(int, CldAckNoTaskHandlerIndex);
268   CpvInitialize(int, CldBalanceHandlerIndex);
269
270   CpvAccess(CldData) = (CldProcInfo)CmiAlloc(sizeof(struct CldProcInfo_s));
271 #if CMK_TRACE_ENABLED
272   CpvAccess(CldData)->balanceEvt = traceRegisterUserEvent("CldBalance", -1);
273   CpvAccess(CldData)->idleEvt = traceRegisterUserEvent("CldBalanceIdle", -1);
274   CpvAccess(CldData)->idleprocEvt = traceRegisterUserEvent("CldBalanceProcIdle", -1);
275 #endif
276
277   CpvAccess(CldBalanceHandlerIndex) = 
278     CmiRegisterHandler(CldBalanceHandler);
279   CpvAccess(CldAskLoadHandlerIndex) = 
280     CmiRegisterHandler((CmiHandler)CldAskLoadHandler);
281   
282   CpvAccess(CldAckNoTaskHandlerIndex) = 
283     CmiRegisterHandler((CmiHandler)CldAckNoTaskHandler);
284
285   /* communication thread */
286   if (CmiMyRank() == CmiMyNodeSize())  return;
287
288   _stealonly1 = CmiGetArgFlagDesc(argv, "+stealonly1", "Charm++> Work Stealing, every time only steal 1 task");
289   
290   workstealingproactive= CmiGetArgFlagDesc(argv, "+workstealingproactive", "Charm++> Work Stealing, steal before going idle(threshold = 3)");
291
292   /* register idle handlers - when idle, keep asking work from neighbors */
293   if(CmiNumPes() > 1)
294     CcdCallOnConditionKeep(CcdPROCESSOR_BEGIN_IDLE,
295       (CcdVoidFn) CldBeginIdle, NULL);
296   if (CmiMyPe() == 0) 
297       CmiPrintf("Charm++> Work stealing is enabled. \n");
298   if(workstealingproactive && CmiMyPe() == 0)
299       CmiPrintf("Charm++> Steal work when load is fewer than 3. \n");
300 }
301
302
303 void CldModuleInit(char **argv)
304 {
305   CpvInitialize(int, CldHandlerIndex);
306   CpvInitialize(int, CldRelocatedMessages);
307   CpvInitialize(int, CldLoadBalanceMessages);
308   CpvInitialize(int, CldMessageChunks);
309   CpvAccess(CldHandlerIndex) = CmiRegisterHandler(CldHandler);
310   CpvAccess(CldRelocatedMessages) = CpvAccess(CldLoadBalanceMessages) = 
311   CpvAccess(CldMessageChunks) = 0;
312
313   CldModuleGeneralInit(argv);
314   CldGraphModuleInit(argv);
315
316   CpvAccess(CldLoadNotify) = 1;
317
318   CpvInitialize(int, isStealing);
319   CpvAccess(isStealing) = 0;
320 }
321
322 void CldCallback()
323 {}