Fixing bug in new multicast strategy.
[charm.git] / src / ck-com / OneTimeMulticastStrategy.C
1 /**
2    @addtogroup ComlibCharmStrategy
3    @{
4    @file
5
6 */
7
8
9 #include "OneTimeMulticastStrategy.h"
10 #include <string>
11 #include <set>
12 #include <vector>
13
14 //#define DEBUG 1
15
16 CkpvExtern(CkGroupID, cmgrID);
17
18 OneTimeMulticastStrategy::OneTimeMulticastStrategy()
19   : Strategy(), CharmStrategy() {
20   //  ComlibPrintf("OneTimeMulticastStrategy constructor\n");
21   setType(ARRAY_STRATEGY);
22 }
23
24 OneTimeMulticastStrategy::~OneTimeMulticastStrategy() {
25 }
26
27 void OneTimeMulticastStrategy::pup(PUP::er &p){
28   Strategy::pup(p);
29   CharmStrategy::pup(p);
30 }
31
32
33 /** Called when the user invokes the entry method on the delegated proxy. */
34 void OneTimeMulticastStrategy::insertMessage(CharmMessageHolder *cmsg){
35 #if DEBUG
36   CkPrintf("[%d] OneTimeMulticastStrategy::insertMessage\n", CkMyPe());
37   fflush(stdout);
38 #endif 
39
40   if(cmsg->dest_proc != IS_SECTION_MULTICAST && cmsg->sec_id == NULL) { 
41     CkAbort("OneTimeMulticastStrategy can only be used with an array section proxy");
42   }
43     
44   // Create a multicast message containing all information about remote destination objects 
45   int needSort = 0;
46   ComlibMulticastMsg * multMsg = sinfo.getNewMulticastMessage(cmsg, needSort, getInstance());
47     
48   // local multicast will re-extract a list of local destination objects (FIXME to make this more efficient)
49   localMulticast(cmsg);
50   
51   // The remote multicast method will send the message to the remote PEs, as specified in multMsg
52   remoteMulticast(multMsg, true);
53    
54   delete cmsg;    
55 }
56
57
58
59 /** Deliver the message to the local elements. */
60 void OneTimeMulticastStrategy::localMulticast(CharmMessageHolder *cmsg) {
61   double start = CmiWallTimer();
62   CkSectionID *sec_id = cmsg->sec_id;
63   CkVec< CkArrayIndexMax > localIndices;
64   sinfo.getLocalIndices(sec_id->_nElems, sec_id->_elems, sec_id->_cookie.aid, localIndices);
65   deliverToIndices(cmsg->getCharmMessage(), localIndices );
66   traceUserBracketEvent(10000, start, CmiWallTimer());
67 }
68
69
70
71
72
73 /** 
74     Forward multicast message to our successor processors in the spanning tree. 
75     Uses CmiSyncListSendAndFree for delivery to this strategy's OneTimeMulticastStrategy::handleMessage method.
76 */
77 void OneTimeMulticastStrategy::remoteMulticast(ComlibMulticastMsg * multMsg, bool rootPE) {
78   double start = CmiWallTimer();
79
80   envelope *env = UsrToEnv(multMsg);
81     
82   
83   /// The index into the PE list in the message
84   int myIndex = -10000; 
85   const int totalDestPEs = multMsg->nPes;
86   const int myPe = CkMyPe();
87   
88   // Find my index in the list of all destination PEs
89   if(rootPE){
90     myIndex = -1;
91   } else {
92     for (int i=0; i<totalDestPEs; ++i) {
93       if(multMsg->indicesCount[i].pe == myPe){
94         myIndex = i;
95         break;
96       }
97     }
98   }
99   
100   if(myIndex == -10000)
101     CkAbort("My PE was not found in the list of destination PEs in the ComlibMulticastMsg");
102   
103   int npes;
104   int *pelist = NULL;
105
106   if(totalDestPEs > 0)
107     determineNextHopPEs(totalDestPEs, multMsg->indicesCount, myIndex, pelist, npes );
108   else {
109     npes = 0;
110   }
111
112   if(npes == 0) {
113 #if DEBUG
114     CkPrintf("[%d] OneTimeMulticastStrategy::remoteMulticast is not forwarding to any other PEs\n", CkMyPe());
115 #endif
116     traceUserBracketEvent(10001, start, CmiWallTimer());
117     CmiFree(env);
118     return;
119   }
120   
121   CmiSetHandler(env, CkpvAccess(comlib_handler));
122   ((CmiMsgHeaderBasic *) env)->stratid = getInstance();  
123   CkPackMessage(&env);
124
125   double middle = CmiWallTimer();
126
127
128   //Collect Multicast Statistics
129   RECORD_SENDM_STATS(getInstance(), env->getTotalsize(), pelist, npes);
130   
131   CkAssert(npes > 0);
132   CmiSyncListSendAndFree(npes, pelist, env->getTotalsize(), (char*)env);
133   
134   delete[] pelist;
135
136   double end = CmiWallTimer();
137   traceUserBracketEvent(10001, start, middle);
138   traceUserBracketEvent(10002, middle, end);
139   
140 }
141
142
143
144 /** 
145     Receive an incoming multicast message(sent from OneTimeMulticastStrategy::remoteMulticast).
146     Deliver the message to all local elements 
147 */
148 void OneTimeMulticastStrategy::handleMessage(void *msg){
149 #if DEBUG
150   //  CkPrintf("[%d] OneTimeMulticastStrategy::handleMessage\n", CkMyPe());
151 #endif
152   envelope *env = (envelope *)msg;
153   CkUnpackMessage(&env);
154   
155   ComlibMulticastMsg* multMsg = (ComlibMulticastMsg*)EnvToUsr(env);
156   
157   // Don't use msg after this point. Instead use the unpacked env
158   
159   RECORD_RECV_STATS(getInstance(), env->getTotalsize(), env->getSrcPe());
160   
161   // Deliver to objects marked as local in the message
162   int localElems;
163   envelope *newenv;
164   CkArrayIndexMax *local_idx_list;  
165   sinfo.unpack(env, localElems, local_idx_list, newenv);
166   ComlibMulticastMsg *newmsg = (ComlibMulticastMsg *)EnvToUsr(newenv);  
167   deliverToIndices(newmsg, localElems, local_idx_list );
168   
169   // Forward on to other processors if necessary
170   remoteMulticast(multMsg, false);
171
172 }
173
174
175
176
177 void OneTimeMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes) {
178   if(myIndex==-1){
179     // We are at a root node of the spanning tree. 
180     // We will forward the message to all other PEs in the destination list.
181     npes = totalDestPEs;
182     
183     pelist = new int[npes];
184     for (int i=0; i<npes; ++i) {
185       pelist[i] = destPEs[i].pe;
186     }
187   } else {
188     // We are at a leaf node of the spanning tree. 
189     npes = 0;
190   }
191   
192 }
193
194
195 void OneTimeRingMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes) {
196   const int myPe = CkMyPe();
197
198   if(myIndex == totalDestPEs-1){
199     // Final PE won't send to anyone
200     npes = 0;
201     return;
202   } else {
203     // All non-final PEs will send to next PE in list
204     npes = 1;
205     pelist = new int[1];
206     pelist[0] = destPEs[myIndex+1].pe;
207   }
208
209 }
210
211
212 void OneTimeTreeMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes){
213   const int myPe = CkMyPe();
214   
215   // The logical indices start at 0 = root node. Logical index i corresponds to the entry i+1 in the array of PEs in the message
216   
217   int sendLogicalIndexStart = degree*(myIndex+1) + 1;       // inclusive
218   int sendLogicalIndexEnd = sendLogicalIndexStart + degree - 1;   // inclusive
219   
220   if(sendLogicalIndexEnd-1 >= totalDestPEs){
221     sendLogicalIndexEnd = totalDestPEs;
222   }
223
224   int numSend = sendLogicalIndexEnd - sendLogicalIndexStart + 1;
225   if(numSend <= 0){
226     npes = 0;
227     return;
228   }
229  
230 #if DEBUG
231   if(numSend > 0)
232     CkPrintf("Tree logical index %d sending to logical %d to %d (totalDestPEs excluding root=%d)  numSend=%d\n",
233              myIndex+1, sendLogicalIndexStart, sendLogicalIndexEnd, totalDestPEs, numSend);
234 #endif
235
236   npes = numSend;
237   pelist = new int[npes];
238   
239   for(int i=0;i<numSend;i++){
240     CkAssert(sendLogicalIndexStart-1+i < totalDestPEs);
241     pelist[i] = destPEs[sendLogicalIndexStart-1+i].pe;
242 #if DEBUG
243     CkPrintf("Tree logical index %d sending to PE %d\n", myIndex+1, pelist[i]);
244 #endif
245     CkAssert(pelist[i] < CkNumPes());
246   }
247   
248 }
249
250
251 /** Find a unique representative PE for a node containing pe, with the restriction that the returned PE is in the list destPEs. */
252 int getFirstPeOnPhysicalNodeFromList(int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){
253   int num;
254   int *nodePeList;
255   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num);
256   
257   for(int i=0;i<num;i++){
258     // Scan destPEs for the pe
259     int p = nodePeList[i];
260     
261     for(int j=0;j<totalDestPEs;j++){
262       if(p == destPEs[j].pe){
263         // found the representative PE for the node that is in the destPEs list
264         return p;
265       }
266     }
267   }
268   
269   CkAbort("ERROR: Could not find an entry for pe in destPEs list.\n");
270   return -1;
271 }
272
273
274 /** Find a unique representative PE for a node containing pe, with the restriction that the returned PE is in the list destPEs. */
275 int getNthPeOnPhysicalNodeFromList(int n, int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){
276   int num;
277   int *nodePeList;
278   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num);
279   
280   int count = 0;
281   int lastFound = -1;
282   
283   // Foreach PE on this physical node
284   for(int i=0;i<num;i++){
285     int p = nodePeList[i];
286     
287     // Scan destPEs for the pe
288     for(int j=0;j<totalDestPEs;j++){
289       if(p == destPEs[j].pe){
290         lastFound = p;
291         if(count==n)
292           return p;
293         count++;
294       }
295     }
296   }
297   
298   if(lastFound != -1)
299     return lastFound;
300
301   CkAbort("ERROR: Could not find an entry for pe in destPEs list.\n");
302   return -1;
303 }
304
305
306 /** List all the PEs from the list that share the physical node */ 
307 std::vector<int> getPesOnPhysicalNodeFromList(int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){ 
308    
309   std::vector<int> result; 
310  
311   int num; 
312   int *nodePeList; 
313   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num); 
314   
315   for(int i=0;i<num;i++){ 
316     // Scan destPEs for the pe 
317     int p = nodePeList[i]; 
318     for(int j=0;j<totalDestPEs;j++){ 
319       if(p == destPEs[j].pe){ 
320         // found the representative PE for the node that is in the
321         // destPEs list 
322         result.push_back(p); 
323         break; 
324       } 
325     } 
326   } 
327   
328   return result; 
329 }
330
331
332
333 /** List all the other PEs from the list that share the physical node */
334 std::vector<int> getOtherPesOnPhysicalNodeFromList(int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){
335   
336   std::vector<int> result;
337
338   int num;
339   int *nodePeList;
340   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num);
341   
342   for(int i=0;i<num;i++){
343     // Scan destPEs for the pe
344     int p = nodePeList[i];
345     if(p != pe){
346       for(int j=0;j<totalDestPEs;j++){
347         if(p == destPEs[j].pe){
348           // found the representative PE for the node that is in the destPEs list
349           result.push_back(p);
350           break;
351         }
352       }
353     }
354   }
355   
356   return result;
357 }
358
359
360 void OneTimeNodeTreeMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes){
361   const int myPe = CkMyPe();
362
363   std::set<int> nodePERepresentatives;
364   
365   // create a list of PEs, with one for each node to which the message must be sent
366   for(int i=0; i<totalDestPEs; i++){
367     int pe = destPEs[i].pe;
368     int representative = getFirstPeOnPhysicalNodeFromList(pe, totalDestPEs, destPEs);
369     nodePERepresentatives.insert(representative);    
370   }
371   
372   int numRepresentativePEs = nodePERepresentatives.size();
373   
374   int repForMyPe=-1;
375   if(myIndex != -1)
376     repForMyPe = getFirstPeOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
377   
378 #if DEBUG
379   CkPrintf("[%d] Multicasting to %d PEs on %d physical nodes  repForMyPe=%d\n", CkMyPe(), totalDestPEs, numRepresentativePEs, repForMyPe);
380   fflush(stdout);
381 #endif
382   
383   // If this PE is part of the multicast tree, then it should forward the message along
384   if(CkMyPe() == repForMyPe || myIndex == -1){
385     // I am an internal node in the multicast tree
386     
387     // flatten the data structure for nodePERepresentatives
388     int *repPeList = new int[numRepresentativePEs];
389     int myRepIndex = -1;
390     std::set<int>::iterator iter;
391     int p=0;
392     for(iter=nodePERepresentatives.begin(); iter != nodePERepresentatives.end(); iter++){
393       repPeList[p] = *iter;
394       if(*iter == repForMyPe)
395         myRepIndex = p;
396       p++;
397     }
398     CkAssert(myRepIndex >=0 || myIndex==-1);
399       
400     // The logical indices start at 0 = root node. Logical index i corresponds to the entry i+1 in the array of PEs in the message
401     int sendLogicalIndexStart = degree*(myRepIndex+1) + 1;       // inclusive
402     int sendLogicalIndexEnd = sendLogicalIndexStart + degree - 1;   // inclusive
403     
404     if(sendLogicalIndexEnd-1 >= numRepresentativePEs){
405       sendLogicalIndexEnd = numRepresentativePEs;
406     }
407     
408     int numSendTree = sendLogicalIndexEnd - sendLogicalIndexStart + 1;
409     if(numSendTree < 0)
410       numSendTree = 0;
411     
412     std::vector<int> otherLocalPes = getOtherPesOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
413     int numSendLocal;
414     if(myIndex == -1)
415       numSendLocal = 0;
416     else 
417       numSendLocal = otherLocalPes.size();
418     
419     
420
421 #if DEBUG
422     CkPrintf("[%d] numSendTree=%d numSendLocal=%d sendLogicalIndexStart=%d sendLogicalIndexEnd=%d\n", CkMyPe(), numSendTree, numSendLocal,  sendLogicalIndexStart, sendLogicalIndexEnd);
423     fflush(stdout);
424 #endif
425
426     int numSend = numSendTree + numSendLocal;
427     if(numSend <= 0){
428       npes = 0;
429       return;
430     }
431     
432     npes = numSend;
433     pelist = new int[npes];
434   
435     for(int i=0;i<numSendTree;i++){
436       CkAssert(sendLogicalIndexStart-1+i < numRepresentativePEs);
437       pelist[i] = repPeList[sendLogicalIndexStart-1+i];
438       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
439     }
440     
441     delete[] repPeList;
442     repPeList = NULL;
443
444     for(int i=0;i<numSendLocal;i++){
445       pelist[i+numSendTree] = otherLocalPes[i];
446       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
447     }
448     
449     
450 #if DEBUG
451     char buf[1024];
452     sprintf(buf, "PE %d is sending to Remote Node PEs: ", CkMyPe() );
453     for(int i=0;i<numSend;i++){
454       if(i==numSendTree)
455         sprintf(buf+strlen(buf), " and Local To Node PEs: ", pelist[i]);
456
457       sprintf(buf+strlen(buf), "%d ", pelist[i]);
458     }    
459     CkPrintf("%s\n", buf);
460     fflush(stdout);
461 #endif
462         
463   } else {
464     // We are a leaf PE
465     npes = 0;
466     return;
467   }
468
469   
470   
471 }
472
473
474 void OneTimeNodeTreeRingMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes){
475   const int myPe = CkMyPe();
476
477   std::set<int> nodePERepresentatives;
478   
479   // create a list of PEs, with one for each node to which the message must be sent
480   for(int i=0; i<totalDestPEs; i++){
481     int pe = destPEs[i].pe;
482     int representative = getFirstPeOnPhysicalNodeFromList(pe, totalDestPEs, destPEs);
483     nodePERepresentatives.insert(representative);    
484   }
485   
486   int numRepresentativePEs = nodePERepresentatives.size();
487   
488   int repForMyPe=-1;
489   if(myIndex != -1)
490     repForMyPe = getFirstPeOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
491   
492 #if DEBUG
493   CkPrintf("[%d] Multicasting to %d PEs on %d physical nodes  repForMyPe=%d\n", CkMyPe(), totalDestPEs, numRepresentativePEs, repForMyPe);
494   fflush(stdout);
495 #endif
496   
497   // If this PE is part of the multicast tree, then it should forward the message along
498   if(CkMyPe() == repForMyPe || myIndex == -1){
499     // I am an internal node in the multicast tree
500     
501     // flatten the data structure for nodePERepresentatives
502     int *repPeList = new int[numRepresentativePEs];
503     int myRepIndex = -1;
504     std::set<int>::iterator iter;
505     int p=0;
506     for(iter=nodePERepresentatives.begin(); iter != nodePERepresentatives.end(); iter++){
507       repPeList[p] = *iter;
508       if(*iter == repForMyPe)
509         myRepIndex = p;
510       p++;
511     }
512     CkAssert(myRepIndex >=0 || myIndex==-1);
513       
514     // The logical indices start at 0 = root node. Logical index i corresponds to the entry i+1 in the array of PEs in the message
515     int sendLogicalIndexStart = degree*(myRepIndex+1) + 1;       // inclusive
516     int sendLogicalIndexEnd = sendLogicalIndexStart + degree - 1;   // inclusive
517     
518     if(sendLogicalIndexEnd-1 >= numRepresentativePEs){
519       sendLogicalIndexEnd = numRepresentativePEs;
520     }
521     
522     int numSendTree = sendLogicalIndexEnd - sendLogicalIndexStart + 1;
523     if(numSendTree < 0)
524       numSendTree = 0;
525
526
527     // Send in a ring to the PEs on this node
528     std::vector<int> otherLocalPes = getOtherPesOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
529     int numSendLocal = 0;
530     if(myIndex == -1)
531       numSendLocal = 0;
532     else {
533       if(otherLocalPes.size() > 0)
534         numSendLocal = 1;
535       else
536         numSendLocal = 0;
537     }
538     
539
540 #if DEBUG
541     CkPrintf("[%d] numSendTree=%d numSendLocal=%d sendLogicalIndexStart=%d sendLogicalIndexEnd=%d\n", CkMyPe(), numSendTree, numSendLocal,  sendLogicalIndexStart, sendLogicalIndexEnd);
542     fflush(stdout);
543 #endif
544
545     int numSend = numSendTree + numSendLocal;
546     if(numSend <= 0){
547       npes = 0;
548       return;
549     }
550     
551     npes = numSend;
552     pelist = new int[npes];
553   
554     for(int i=0;i<numSendTree;i++){
555       CkAssert(sendLogicalIndexStart-1+i < numRepresentativePEs);
556       pelist[i] = repPeList[sendLogicalIndexStart-1+i];
557       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
558     }
559     
560     delete[] repPeList;
561     repPeList = NULL;
562
563     for(int i=0;i<numSendLocal;i++){
564       pelist[i+numSendTree] = otherLocalPes[i];
565       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
566     }
567     
568     
569 #if DEBUG
570     char buf[1024];
571     sprintf(buf, "PE %d is sending to Remote Node PEs: ", CkMyPe() );
572     for(int i=0;i<numSend;i++){
573       if(i==numSendTree)
574         sprintf(buf+strlen(buf), " and Local To Node PEs: ", pelist[i]);
575
576       sprintf(buf+strlen(buf), "%d ", pelist[i]);
577     }    
578     CkPrintf("%s\n", buf);
579     fflush(stdout);
580 #endif
581         
582   } else {
583     // We are a leaf PE, so forward in a ring to the PEs on this node
584     const std::vector<int> otherLocalPes = getPesOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
585     
586     npes = 0;
587     pelist = new int[1];
588     
589     //    CkPrintf("[%d] otherLocalPes.size=%d\n", CkMyPe(), otherLocalPes.size() ); 
590     const int numOthers = otherLocalPes.size() ;
591     
592     for(int i=0;i<numOthers;i++){
593       if(otherLocalPes[i] == CkMyPe()){
594         // found me in the PE list for this node
595         if(i+1<otherLocalPes.size()){
596           // If we have a successor in the ring
597           pelist[0] = otherLocalPes[i+1];
598           npes = 1;
599         }
600       }
601     }
602     
603     
604 #if DEBUG
605     if(npes==0)
606       CkPrintf("[%d] At end of ring\n", CkMyPe() );
607     else
608       CkPrintf("[%d] sending along ring to %d\n", CkMyPe(), pelist[0] );
609     
610     fflush(stdout);
611 #endif
612     
613     
614   }
615   
616   
617   
618 }
619
620 /*@}*/