f3b53cc5748e31c411f07c98f89be2187d153109
[charm.git] / src / ck-ldb / RefinerComm.C
1 /**
2  * \addtogroup CkLdb
3 */
4 /*@{*/
5
6 /** This code is derived from RefineLB.C, and RefineLB.C should
7  be rewritten to use this, so there is no code duplication
8 */
9
10 #include "elements.h"
11 #include "ckheap.h"
12 #include "RefinerComm.h"
13
14
15 void RefinerComm::create(int count, BaseLB::LDStats* _stats, int* procs)
16 {
17   int i;
18   stats = _stats;
19   Refiner::create(count, _stats, procs);
20
21   for (i=0; i<stats->n_comm; i++) 
22   {
23         LDCommData &comm = stats->commData[i];
24         if (!comm.from_proc()) {
25           // out going message
26           int computeIdx = stats->getSendHash(comm);
27           CmiAssert(computeIdx >= 0 && computeIdx < numComputes);
28           computes[computeIdx].sendmessages.push_back(i);
29         }
30
31         // FIXME: only obj msg here
32         // incoming messages
33         if (comm.receiver.get_type() == LD_OBJ_MSG)  {
34           int computeIdx = stats->getRecvHash(comm);
35           CmiAssert(computeIdx >= 0 && computeIdx < numComputes);
36           computes[computeIdx].recvmessages.push_back(i);
37         }
38   }
39 }
40
41 void RefinerComm::computeAverage()
42 {
43   int i;
44   double total = 0.;
45   for (i=0; i<numComputes; i++) total += computes[i].load;
46
47   for (i=0; i<P; i++) {
48     if (processors[i].available == CmiTrue) {
49         total += processors[i].backgroundLoad;
50         total += commTable->overheadOnPe(i);
51     }
52   }
53
54   averageLoad = total/numAvail;
55 }
56
57 // compute the initial per processor communication overhead
58 void RefinerComm::processorCommCost()
59 {
60   int i;
61
62   for (int cidx=0; cidx < stats->n_comm; cidx++) {
63     LDCommData& cdata = stats->commData[cidx];
64     int senderPE = -1, receiverPE = -1;
65     if (cdata.from_proc())
66       senderPE = cdata.src_proc;
67     else {
68       int idx = stats->getSendHash(cdata);
69       CmiAssert(idx != -1);
70       senderPE = computes[idx].oldProcessor;    // object's original processor
71     }
72     CmiAssert(senderPE != -1);
73     int ctype = cdata.receiver.get_type();
74     if (ctype==LD_PROC_MSG || ctype==LD_OBJ_MSG) {
75       if (ctype==LD_PROC_MSG)
76         receiverPE = cdata.receiver.proc();
77       else {    // LD_OBJ_MSG
78         int idx = stats->getRecvHash(cdata);
79         CmiAssert(idx != -1);
80         receiverPE = computes[idx].oldProcessor;
81       }
82       CmiAssert(receiverPE != -1);
83       if(senderPE != receiverPE)
84       {
85         commTable->increase(true, senderPE, cdata.messages, cdata.bytes);
86         commTable->increase(false, receiverPE, cdata.messages, cdata.bytes);
87       }
88     }
89     else if (ctype == LD_OBJLIST_MSG) {
90       int nobjs;
91       LDObjKey *objs = cdata.receiver.get_destObjs(nobjs);
92       for (i=0; i<nobjs; i++) {
93         int idx = stats->getHash(objs[i]);
94         if(idx == -1)
95              if (_lb_args.migObjOnly()) continue;
96              else CkAbort("Error in search\n");
97         receiverPE = computes[idx].oldProcessor;
98         CmiAssert(receiverPE != -1);
99         if(senderPE != receiverPE)
100         {
101           commTable->increase(true, senderPE, cdata.messages, cdata.bytes);
102           commTable->increase(false, receiverPE, cdata.messages, cdata.bytes);
103         }
104       }
105     }
106   }
107   // recalcualte the cpu load
108   for (i=0; i<P; i++) 
109   {
110     processorInfo *p = &processors[i];
111     p->load = p->computeLoad + p->backgroundLoad + commTable->overheadOnPe(i);
112   }
113 }
114
115 void RefinerComm::assign(computeInfo *c, int processor)
116 {
117   assign(c, &(processors[processor]));
118 }
119
120 void RefinerComm::assign(computeInfo *c, processorInfo *p)
121 {
122    int oldProc = c->processor;
123    c->processor = p->Id;
124    p->computeSet->insert((InfoRecord *) c);
125    p->computeLoad += c->load;
126 //   p->load = p->computeLoad + p->backgroundLoad;
127    // add communication cost
128    Messages m;
129    objCommCost(c->Id, p->Id, m);
130    commTable->increase(true, p->Id, m.msgSent, m.byteSent);
131    commTable->increase(false, p->Id, m.msgRecv, m.byteRecv);
132
133 //   CmiPrintf("Assign %d to %d commCost: %d %d %d %d \n", c->Id, p->Id, byteSent,msgSent,byteRecv,msgRecv);
134
135    commAffinity(c->Id, p->Id, m);
136    commTable->increase(false, p->Id, -m.msgSent, -m.byteSent);
137    commTable->increase(true, p->Id, -m.msgRecv, -m.byteRecv);   // reverse
138
139 //   CmiPrintf("Assign %d to %d commAffinity: %d %d %d %d \n", c->Id, p->Id, -byteSent,-msgSent,-byteRecv,-msgRecv);
140
141    p->load = p->computeLoad + p->backgroundLoad + commTable->overheadOnPe(p->Id);
142 }
143
144 void  RefinerComm::deAssign(computeInfo *c, processorInfo *p)
145 {
146 //   c->processor = -1;
147    p->computeSet->remove(c);
148    p->computeLoad -= c->load;
149 //   p->load = p->computeLoad + p->backgroundLoad;
150    Messages m;
151    objCommCost(c->Id, p->Id, m);
152    commTable->increase(true, p->Id, -m.msgSent, -m.byteSent);
153    commTable->increase(false, p->Id, -m.msgRecv, -m.byteRecv);
154    
155    commAffinity(c->Id, p->Id, m);
156    commTable->increase(true, p->Id, m.msgSent, m.byteSent);
157    commTable->increase(false, p->Id, m.msgRecv, m.byteRecv);
158
159    p->load = p->computeLoad + p->backgroundLoad + commTable->overheadOnPe(p->Id);
160 }
161
162 // how much communication from compute c  to pe
163 // byteSent, msgSent are messages from object c to pe p
164 // byteRecv, msgRecv are messages from pe p to obejct c
165 void RefinerComm::commAffinity(int c, int pe, Messages &m)
166 {
167   int i;
168   m.clear();
169   computeInfo &obj = computes[c];
170
171   int nSendMsgs = obj.sendmessages.length();
172   for (i=0; i<nSendMsgs; i++) {
173     LDCommData &cdata = stats->commData[obj.sendmessages[i]];
174     bool sendtope = false;
175     if (cdata.receiver.get_type() == LD_OBJ_MSG) {
176       int recvCompute = stats->getRecvHash(cdata);
177       int recvProc = computes[recvCompute].processor;
178       if (recvProc != -1 && recvProc == pe) sendtope = true;
179     }
180     else if (cdata.receiver.get_type() == LD_OBJLIST_MSG) {  // multicast
181       int nobjs;
182       LDObjKey *recvs = cdata.receiver.get_destObjs(nobjs);
183       for (int j=0; j<nobjs; j++) {
184         int recvCompute = stats->getHash(recvs[j]);
185         int recvProc = computes[recvCompute].processor; // FIXME
186         if (recvProc != -1 && recvProc == pe) { sendtope = true; continue; }
187       }  
188     }
189     if (sendtope) {
190       m.byteSent += cdata.bytes;
191       m.msgSent += cdata.messages;
192     }
193   }  // end of for
194
195   int nRecvMsgs = obj.recvmessages.length();
196   for (i=0; i<nRecvMsgs; i++) {
197     LDCommData &cdata = stats->commData[obj.recvmessages[i]];
198     int sendProc;
199     if (cdata.from_proc()) {
200       sendProc = cdata.src_proc;
201     }
202     else {
203       int sendCompute = stats->getSendHash(cdata);
204       sendProc = computes[sendCompute].processor;
205     }
206     if (sendProc != -1 && sendProc == pe) {
207       m.byteRecv += cdata.bytes;
208       m.msgRecv += cdata.messages;
209     }
210   }  // end of for
211 }
212
213 // assume c is on pe, how much comm overhead it will be?
214 void RefinerComm::objCommCost(int c, int pe, Messages &m)
215 {
216   int i;
217   m.clear();
218   computeInfo &obj = computes[c];
219
220   // find out send overhead for every outgoing message that has receiver
221   // not same as pe
222   int nSendMsgs = obj.sendmessages.length();
223   for (i=0; i<nSendMsgs; i++) {
224     LDCommData &cdata = stats->commData[obj.sendmessages[i]];
225     bool diffPe = false;
226     if (cdata.receiver.get_type() == LD_PROC_MSG) {
227       CmiAssert(0);
228     }
229     if (cdata.receiver.get_type() == LD_OBJ_MSG) {
230       int recvCompute = stats->getRecvHash(cdata);
231       int recvProc = computes[recvCompute].processor;
232       if (recvProc!= -1 && recvProc != pe) diffPe = true;
233     }
234     else if (cdata.receiver.get_type() == LD_OBJLIST_MSG) {  // multicast
235       int nobjs;
236       LDObjKey *recvs = cdata.receiver.get_destObjs(nobjs);
237       for (int j=0; j<nobjs; j++) {
238         int recvCompute = stats->getHash(recvs[j]);
239         int recvProc = computes[recvCompute].processor; // FIXME
240         if (recvProc!= -1 && recvProc != pe) { diffPe = true; }
241       }  
242     }
243     if (diffPe) {
244       m.byteSent += cdata.bytes;
245       m.msgSent += cdata.messages;
246     }
247   }  // end of for
248
249   // find out recv overhead for every incoming message that has sender
250   // not same as pe
251   int nRecvMsgs = obj.recvmessages.length();
252   for (i=0; i<nRecvMsgs; i++) {
253     LDCommData &cdata = stats->commData[obj.recvmessages[i]];
254     bool diffPe = false;
255     if (cdata.from_proc()) {
256       if (cdata.src_proc != pe) diffPe = true;
257     }
258     else {
259       int sendCompute = stats->getSendHash(cdata);
260       int sendProc = computes[sendCompute].processor;
261       if (sendProc != -1 && sendProc != pe) diffPe = true;
262     }
263     if (diffPe) {       // sender is not pe
264       m.byteRecv += cdata.bytes;
265       m.msgRecv += cdata.messages;
266     }
267   }  // end of for
268 }
269
270 int RefinerComm::refine()
271 {
272   int i;
273   int finish = 1;
274
275   maxHeap *heavyProcessors = new maxHeap(P);
276   Set *lightProcessors = new Set();
277   for (i=0; i<P; i++) {
278     if (isHeavy(&processors[i])) {  
279       //      CkPrintf("Processor %d is HEAVY: load:%f averageLoad:%f!\n",
280      //                i, processors[i].load, averageLoad);
281       heavyProcessors->insert((InfoRecord *) &(processors[i]));
282     } else if (isLight(&processors[i])) {
283       //      CkPrintf("Processor %d is LIGHT: load:%f averageLoad:%f!\n",
284      //                i, processors[i].load, averageLoad);
285       lightProcessors->insert((InfoRecord *) &(processors[i]));
286     }
287   }
288   int done = 0;
289
290   while (!done) {
291     double bestSize, bestComm;
292     computeInfo *bestCompute;
293     processorInfo *bestP;
294     
295     processorInfo *donor = (processorInfo *) heavyProcessors->deleteMax();
296     if (!donor) break;
297
298     //find the best pair (c,receiver)
299     Iterator nextProcessor;
300     processorInfo *p = (processorInfo *) 
301       lightProcessors->iterator((Iterator *) &nextProcessor);
302     bestSize = 0;
303     bestComm = -1e8;
304     bestP = NULL;
305     bestCompute = NULL;
306
307     while (p) {
308       Iterator nextCompute;
309       nextCompute.id = 0;
310       computeInfo *c = (computeInfo *) 
311         donor->computeSet->iterator((Iterator *)&nextCompute);
312       //CmiPrintf("Considering Procsessor : %d with load: %f for donor: %d\n", p->Id, p->load, donor->Id);
313       while (c) {
314         if (!c->migratable) {
315           nextCompute.id++;
316           c = (computeInfo *) 
317             donor->computeSet->next((Iterator *)&nextCompute);
318           continue;
319         }
320         //CkPrintf("c->load: %f p->load:%f overLoad*averageLoad:%f \n",
321         //c->load, p->load, overLoad*averageLoad);
322         Messages m;
323         objCommCost(c->Id, donor->Id, m);
324         double commcost = m.cost();
325         commAffinity(c->Id, p->Id, m);
326         double commgain = m.cost();;
327
328         //CmiPrintf("Considering Compute: %d with load %f commcost:%f commgain:%f\n", c->Id, c->load, commcost, commgain);
329         if ( c->load + p->load + commcost - commgain < overLoad*averageLoad) {
330           //CmiPrintf("[%d] comm gain %f bestSize:%f\n", c->Id, commgain, bestSize);
331           if(c->load + commcost - commgain > bestSize) {
332             bestSize = c->load + commcost - commgain;
333             bestCompute = c;
334             bestP = p;
335           }
336         }
337         nextCompute.id++;
338         c = (computeInfo *) 
339           donor->computeSet->next((Iterator *)&nextCompute);
340       }
341       p = (processorInfo *) 
342         lightProcessors->next((Iterator *) &nextProcessor);
343     }
344
345     if (bestCompute) {
346       if (_lb_args.debug())
347         CkPrintf("Assign: [%d] with load: %f from %d to %d \n",
348                bestCompute->Id, bestCompute->load, 
349                donor->Id, bestP->Id);
350       deAssign(bestCompute, donor);      
351       assign(bestCompute, bestP);
352
353       // show the load
354       if (_lb_args.debug())  printLoad();
355
356       // update commnication
357       computeAverage();
358       delete heavyProcessors;
359       delete lightProcessors;
360       heavyProcessors = new maxHeap(P);
361       lightProcessors = new Set();
362       for (i=0; i<P; i++) {
363         if (isHeavy(&processors[i])) {  
364           //      CkPrintf("Processor %d is HEAVY: load:%f averageLoad:%f!\n",
365           //           i, processors[i].load, averageLoad);
366           heavyProcessors->insert((InfoRecord *) &(processors[i]));
367         } else if (isLight(&processors[i])) {
368           lightProcessors->insert((InfoRecord *) &(processors[i]));
369         }
370       }
371       if (_lb_args.debug()) CmiPrintf("averageLoad after assignment: %f\n", averageLoad);
372     } else {
373       finish = 0;
374       break;
375     }
376
377
378 /*
379     if (bestP->load > averageLoad)
380       lightProcessors->remove(bestP);
381     
382     if (isHeavy(donor))
383       heavyProcessors->insert((InfoRecord *) donor);
384     else if (isLight(donor))
385       lightProcessors->insert((InfoRecord *) donor);
386 */
387   }  
388
389   delete heavyProcessors;
390   delete lightProcessors;
391
392   return finish;
393 }
394
395 void RefinerComm::Refine(int count, BaseLB::LDStats* stats, 
396                      int* cur_p, int* new_p)
397 {
398   //  CkPrintf("[%d] Refiner strategy\n",CkMyPe());
399
400   P = count;
401   numComputes = stats->n_objs;
402   computes = new computeInfo[numComputes];
403   processors = new processorInfo[count];
404   commTable = new CommTable(P);
405
406   // fill communication hash table
407   stats->makeCommHash();
408
409   create(count, stats, cur_p);
410
411   int i;
412   for (i=0; i<numComputes; i++)
413     assign((computeInfo *) &(computes[i]),
414            (processorInfo *) &(processors[computes[i].oldProcessor]));
415
416   commTable->clear();
417
418   // recalcualte the cpu load
419   processorCommCost();
420
421   removeComputes();
422   if (_lb_args.debug())  printLoad();
423
424   computeAverage();
425   if (_lb_args.debug()) CmiPrintf("averageLoad: %f\n", averageLoad);
426
427   multirefine();
428
429   for (int pe=0; pe < P; pe++) {
430     Iterator nextCompute;
431     nextCompute.id = 0;
432     computeInfo *c = (computeInfo *)
433       processors[pe].computeSet->iterator((Iterator *)&nextCompute);
434     while(c) {
435       new_p[c->Id] = c->processor;
436 //      if (c->oldProcessor != c->processor)
437 //      CkPrintf("Refiner::Refine: from %d to %d\n", c->oldProcessor, c->processor);
438       nextCompute.id++;
439       c = (computeInfo *) processors[pe].computeSet->
440                      next((Iterator *)&nextCompute);
441     }
442   }
443
444   delete [] computes;
445   delete [] processors;
446   delete commTable;
447 }
448
449 RefinerComm::CommTable::CommTable(int P)
450 {
451   count = P;
452   msgSentCount = new int[P]; // # of messages sent by each PE
453   msgRecvCount = new int[P]; // # of messages received by each PE
454   byteSentCount = new int[P];// # of bytes sent by each PE
455   byteRecvCount = new int[P];// # of bytes reeived by each PE
456   clear();
457 }
458
459 RefinerComm::CommTable::~CommTable()
460 {
461   delete [] msgSentCount;
462   delete [] msgRecvCount;
463   delete [] byteSentCount;
464   delete [] byteRecvCount;
465 }
466
467 void RefinerComm::CommTable::clear()
468 {
469   for(int i = 0; i < count; i++)
470     msgSentCount[i] = msgRecvCount[i] = byteSentCount[i] = byteRecvCount[i] = 0;
471 }
472
473 void RefinerComm::CommTable::increase(bool issend, int pe, int msgs, int bytes)
474 {
475   if (issend) {
476     msgSentCount[pe] += msgs;
477     byteSentCount[pe] += bytes;
478   }
479   else {
480     msgRecvCount[pe] += msgs;
481     byteRecvCount[pe] += bytes;
482   }
483 }
484
485 double RefinerComm::CommTable::overheadOnPe(int pe)
486 {
487   return msgRecvCount[pe]  * PER_MESSAGE_RECV_OVERHEAD +
488          msgSentCount[pe]  * _lb_args.alpha() +
489          byteRecvCount[pe] * PER_BYTE_RECV_OVERHEAD +
490          byteSentCount[pe] * _lb_args.beeta();
491 }
492
493 /*@}*/