Adding a new node aware multicast strategy that sends to PEs within a node along...
[charm.git] / src / ck-com / OneTimeMulticastStrategy.C
1 /**
2    @addtogroup ComlibCharmStrategy
3    @{
4    @file
5
6 */
7
8
9 #include "OneTimeMulticastStrategy.h"
10 #include <string>
11 #include <set>
12 #include <vector>
13
14 //#define DEBUG 1
15
16 CkpvExtern(CkGroupID, cmgrID);
17
18 OneTimeMulticastStrategy::OneTimeMulticastStrategy()
19   : Strategy(), CharmStrategy() {
20   //  ComlibPrintf("OneTimeMulticastStrategy constructor\n");
21   setType(ARRAY_STRATEGY);
22 }
23
24 OneTimeMulticastStrategy::~OneTimeMulticastStrategy() {
25 }
26
27 void OneTimeMulticastStrategy::pup(PUP::er &p){
28   Strategy::pup(p);
29   CharmStrategy::pup(p);
30 }
31
32
33 /** Called when the user invokes the entry method on the delegated proxy. */
34 void OneTimeMulticastStrategy::insertMessage(CharmMessageHolder *cmsg){
35 #if DEBUG
36   CkPrintf("[%d] OneTimeMulticastStrategy::insertMessage\n", CkMyPe());
37   fflush(stdout);
38 #endif 
39
40   if(cmsg->dest_proc != IS_SECTION_MULTICAST && cmsg->sec_id == NULL) { 
41     CkAbort("OneTimeMulticastStrategy can only be used with an array section proxy");
42   }
43     
44   // Create a multicast message containing all information about remote destination objects 
45   int needSort = 0;
46   ComlibMulticastMsg * multMsg = sinfo.getNewMulticastMessage(cmsg, needSort, getInstance());
47     
48   // local multicast will re-extract a list of local destination objects (FIXME to make this more efficient)
49   localMulticast(cmsg);
50   
51   // The remote multicast method will send the message to the remote PEs, as specified in multMsg
52   remoteMulticast(multMsg, true);
53    
54   delete cmsg;    
55 }
56
57
58
59 /** Deliver the message to the local elements. */
60 void OneTimeMulticastStrategy::localMulticast(CharmMessageHolder *cmsg) {
61   double start = CmiWallTimer();
62   CkSectionID *sec_id = cmsg->sec_id;
63   CkVec< CkArrayIndexMax > localIndices;
64   sinfo.getLocalIndices(sec_id->_nElems, sec_id->_elems, sec_id->_cookie.aid, localIndices);
65   deliverToIndices(cmsg->getCharmMessage(), localIndices );
66   traceUserBracketEvent(10000, start, CmiWallTimer());
67 }
68
69
70
71
72
73 /** 
74     Forward multicast message to our successor processors in the spanning tree. 
75     Uses CmiSyncListSendAndFree for delivery to this strategy's OneTimeMulticastStrategy::handleMessage method.
76 */
77 void OneTimeMulticastStrategy::remoteMulticast(ComlibMulticastMsg * multMsg, bool rootPE) {
78   double start = CmiWallTimer();
79
80   envelope *env = UsrToEnv(multMsg);
81     
82   
83   /// The index into the PE list in the message
84   int myIndex = -10000; 
85   const int totalDestPEs = multMsg->nPes;
86   const int myPe = CkMyPe();
87   
88   // Find my index in the list of all destination PEs
89   if(rootPE){
90     myIndex = -1;
91   } else {
92     for (int i=0; i<totalDestPEs; ++i) {
93       if(multMsg->indicesCount[i].pe == myPe){
94         myIndex = i;
95         break;
96       }
97     }
98   }
99   
100   if(myIndex == -10000)
101     CkAbort("My PE was not found in the list of destination PEs in the ComlibMulticastMsg");
102   
103   int npes;
104   int *pelist = NULL;
105
106   if(totalDestPEs > 0)
107     determineNextHopPEs(totalDestPEs, multMsg->indicesCount, myIndex, pelist, npes );
108   else {
109     npes = 0;
110   }
111
112   if(npes == 0) {
113 #if DEBUG
114     CkPrintf("[%d] OneTimeMulticastStrategy::remoteMulticast is not forwarding to any other PEs\n", CkMyPe());
115 #endif
116     traceUserBracketEvent(10001, start, CmiWallTimer());
117     CmiFree(env);
118     return;
119   }
120   
121   CmiSetHandler(env, CkpvAccess(comlib_handler));
122   ((CmiMsgHeaderBasic *) env)->stratid = getInstance();  
123   CkPackMessage(&env);
124
125   double middle = CmiWallTimer();
126
127
128   //Collect Multicast Statistics
129   RECORD_SENDM_STATS(getInstance(), env->getTotalsize(), pelist, npes);
130   
131   CkAssert(npes > 0);
132   CmiSyncListSendAndFree(npes, pelist, env->getTotalsize(), (char*)env);
133   
134   delete[] pelist;
135
136   double end = CmiWallTimer();
137   traceUserBracketEvent(10001, start, middle);
138   traceUserBracketEvent(10002, middle, end);
139   
140 }
141
142
143
144 /** 
145     Receive an incoming multicast message(sent from OneTimeMulticastStrategy::remoteMulticast).
146     Deliver the message to all local elements 
147 */
148 void OneTimeMulticastStrategy::handleMessage(void *msg){
149 #if DEBUG
150   //  CkPrintf("[%d] OneTimeMulticastStrategy::handleMessage\n", CkMyPe());
151 #endif
152   envelope *env = (envelope *)msg;
153   CkUnpackMessage(&env);
154   
155   ComlibMulticastMsg* multMsg = (ComlibMulticastMsg*)EnvToUsr(env);
156   
157   // Don't use msg after this point. Instead use the unpacked env
158   
159   RECORD_RECV_STATS(getInstance(), env->getTotalsize(), env->getSrcPe());
160   
161   // Deliver to objects marked as local in the message
162   int localElems;
163   envelope *newenv;
164   CkArrayIndexMax *local_idx_list;  
165   sinfo.unpack(env, localElems, local_idx_list, newenv);
166   ComlibMulticastMsg *newmsg = (ComlibMulticastMsg *)EnvToUsr(newenv);  
167   deliverToIndices(newmsg, localElems, local_idx_list );
168   
169   // Forward on to other processors if necessary
170   remoteMulticast(multMsg, false);
171
172 }
173
174
175
176
177 void OneTimeMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes) {
178   if(myIndex==-1){
179     // We are at a root node of the spanning tree. 
180     // We will forward the message to all other PEs in the destination list.
181     npes = totalDestPEs;
182     
183     pelist = new int[npes];
184     for (int i=0; i<npes; ++i) {
185       pelist[i] = destPEs[i].pe;
186     }
187   } else {
188     // We are at a leaf node of the spanning tree. 
189     npes = 0;
190   }
191   
192 }
193
194
195 void OneTimeRingMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes) {
196   const int myPe = CkMyPe();
197
198   if(myIndex == totalDestPEs-1){
199     // Final PE won't send to anyone
200     npes = 0;
201     return;
202   } else {
203     // All non-final PEs will send to next PE in list
204     npes = 1;
205     pelist = new int[1];
206     pelist[0] = destPEs[myIndex+1].pe;
207   }
208
209 }
210
211
212 void OneTimeTreeMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes){
213   const int myPe = CkMyPe();
214   
215   // The logical indices start at 0 = root node. Logical index i corresponds to the entry i+1 in the array of PEs in the message
216   
217   int sendLogicalIndexStart = degree*(myIndex+1) + 1;       // inclusive
218   int sendLogicalIndexEnd = sendLogicalIndexStart + degree - 1;   // inclusive
219   
220   if(sendLogicalIndexEnd-1 >= totalDestPEs){
221     sendLogicalIndexEnd = totalDestPEs;
222   }
223
224   int numSend = sendLogicalIndexEnd - sendLogicalIndexStart + 1;
225   if(numSend <= 0){
226     npes = 0;
227     return;
228   }
229  
230 #if DEBUG
231   if(numSend > 0)
232     CkPrintf("Tree logical index %d sending to logical %d to %d (totalDestPEs excluding root=%d)  numSend=%d\n",
233              myIndex+1, sendLogicalIndexStart, sendLogicalIndexEnd, totalDestPEs, numSend);
234 #endif
235
236   npes = numSend;
237   pelist = new int[npes];
238   
239   for(int i=0;i<numSend;i++){
240     CkAssert(sendLogicalIndexStart-1+i < totalDestPEs);
241     pelist[i] = destPEs[sendLogicalIndexStart-1+i].pe;
242 #if DEBUG
243     CkPrintf("Tree logical index %d sending to PE %d\n", myIndex+1, pelist[i]);
244 #endif
245     CkAssert(pelist[i] < CkNumPes());
246   }
247   
248 }
249
250
251 /** Find a unique representative PE for a node containing pe, with the restriction that the returned PE is in the list destPEs. */
252 int getFirstPeOnPhysicalNodeFromList(int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){
253   int num;
254   int *nodePeList;
255   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num);
256   
257   for(int i=0;i<num;i++){
258     // Scan destPEs for the pe
259     int p = nodePeList[i];
260     
261     for(int j=0;j<totalDestPEs;j++){
262       if(p == destPEs[j].pe){
263         // found the representative PE for the node that is in the destPEs list
264         return p;
265       }
266     }
267   }
268   
269   CkAbort("ERROR: Could not find an entry for pe in destPEs list.\n");
270   return -1;
271 }
272
273
274 /** Find a unique representative PE for a node containing pe, with the restriction that the returned PE is in the list destPEs. */
275 int getNthPeOnPhysicalNodeFromList(int n, int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){
276   int num;
277   int *nodePeList;
278   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num);
279   
280   int count = 0;
281   int lastFound = -1;
282   
283   // Foreach PE on this physical node
284   for(int i=0;i<num;i++){
285     int p = nodePeList[i];
286     
287     // Scan destPEs for the pe
288     for(int j=0;j<totalDestPEs;j++){
289       if(p == destPEs[j].pe){
290         lastFound = p;
291         if(count==n)
292           return p;
293         count++;
294       }
295     }
296   }
297   
298   if(lastFound != -1)
299     return lastFound;
300
301   CkAbort("ERROR: Could not find an entry for pe in destPEs list.\n");
302   return -1;
303 }
304
305
306
307 /** List all the other PEs from the list that share the physical node */
308 std::vector<int> getOtherPesOnPhysicalNodeFromList(int pe, const int totalDestPEs, const ComlibMulticastIndexCount* destPEs){
309   
310   std::vector<int> result;
311
312   int num;
313   int *nodePeList;
314   CmiGetPesOnPhysicalNode(pe, &nodePeList, &num);
315   
316   for(int i=0;i<num;i++){
317     // Scan destPEs for the pe
318     int p = nodePeList[i];
319     if(p != pe){
320       for(int j=0;j<totalDestPEs;j++){
321         if(p == destPEs[j].pe){
322           // found the representative PE for the node that is in the destPEs list
323           result.push_back(p);
324           break;
325         }
326       }
327     }
328   }
329   
330   return result;
331 }
332
333
334 void OneTimeNodeTreeMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes){
335   const int myPe = CkMyPe();
336
337   std::set<int> nodePERepresentatives;
338   
339   // create a list of PEs, with one for each node to which the message must be sent
340   for(int i=0; i<totalDestPEs; i++){
341     int pe = destPEs[i].pe;
342     int representative = getFirstPeOnPhysicalNodeFromList(pe, totalDestPEs, destPEs);
343     nodePERepresentatives.insert(representative);    
344   }
345   
346   int numRepresentativePEs = nodePERepresentatives.size();
347   
348   int repForMyPe=-1;
349   if(myIndex != -1)
350     repForMyPe = getFirstPeOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
351   
352 #if DEBUG
353   CkPrintf("[%d] Multicasting to %d PEs on %d physical nodes  repForMyPe=%d\n", CkMyPe(), totalDestPEs, numRepresentativePEs, repForMyPe);
354   fflush(stdout);
355 #endif
356   
357   // If this PE is part of the multicast tree, then it should forward the message along
358   if(CkMyPe() == repForMyPe || myIndex == -1){
359     // I am an internal node in the multicast tree
360     
361     // flatten the data structure for nodePERepresentatives
362     int *repPeList = new int[numRepresentativePEs];
363     int myRepIndex = -1;
364     std::set<int>::iterator iter;
365     int p=0;
366     for(iter=nodePERepresentatives.begin(); iter != nodePERepresentatives.end(); iter++){
367       repPeList[p] = *iter;
368       if(*iter == repForMyPe)
369         myRepIndex = p;
370       p++;
371     }
372     CkAssert(myRepIndex >=0 || myIndex==-1);
373       
374     // The logical indices start at 0 = root node. Logical index i corresponds to the entry i+1 in the array of PEs in the message
375     int sendLogicalIndexStart = degree*(myRepIndex+1) + 1;       // inclusive
376     int sendLogicalIndexEnd = sendLogicalIndexStart + degree - 1;   // inclusive
377     
378     if(sendLogicalIndexEnd-1 >= numRepresentativePEs){
379       sendLogicalIndexEnd = numRepresentativePEs;
380     }
381     
382     int numSendTree = sendLogicalIndexEnd - sendLogicalIndexStart + 1;
383     if(numSendTree < 0)
384       numSendTree = 0;
385     
386     std::vector<int> otherLocalPes = getOtherPesOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
387     int numSendLocal;
388     if(myIndex == -1)
389       numSendLocal = 0;
390     else 
391       numSendLocal = otherLocalPes.size();
392     
393     
394
395 #if DEBUG
396     CkPrintf("[%d] numSendTree=%d numSendLocal=%d sendLogicalIndexStart=%d sendLogicalIndexEnd=%d\n", CkMyPe(), numSendTree, numSendLocal,  sendLogicalIndexStart, sendLogicalIndexEnd);
397     fflush(stdout);
398 #endif
399
400     int numSend = numSendTree + numSendLocal;
401     if(numSend <= 0){
402       npes = 0;
403       return;
404     }
405     
406     npes = numSend;
407     pelist = new int[npes];
408   
409     for(int i=0;i<numSendTree;i++){
410       CkAssert(sendLogicalIndexStart-1+i < numRepresentativePEs);
411       pelist[i] = repPeList[sendLogicalIndexStart-1+i];
412       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
413     }
414     
415     delete[] repPeList;
416     repPeList = NULL;
417
418     for(int i=0;i<numSendLocal;i++){
419       pelist[i+numSendTree] = otherLocalPes[i];
420       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
421     }
422     
423     
424 #if DEBUG
425     char buf[1024];
426     sprintf(buf, "PE %d is sending to Remote Node PEs: ", CkMyPe() );
427     for(int i=0;i<numSend;i++){
428       if(i==numSendTree)
429         sprintf(buf+strlen(buf), " and Local To Node PEs: ", pelist[i]);
430
431       sprintf(buf+strlen(buf), "%d ", pelist[i]);
432     }    
433     CkPrintf("%s\n", buf);
434     fflush(stdout);
435 #endif
436         
437   } else {
438     // We are a leaf PE
439     npes = 0;
440     return;
441   }
442
443   
444   
445 }
446
447
448 void OneTimeNodeTreeRingMulticastStrategy::determineNextHopPEs(const int totalDestPEs, const ComlibMulticastIndexCount* destPEs, const int myIndex, int * &pelist, int &npes){
449   const int myPe = CkMyPe();
450
451   std::set<int> nodePERepresentatives;
452   
453   // create a list of PEs, with one for each node to which the message must be sent
454   for(int i=0; i<totalDestPEs; i++){
455     int pe = destPEs[i].pe;
456     int representative = getFirstPeOnPhysicalNodeFromList(pe, totalDestPEs, destPEs);
457     nodePERepresentatives.insert(representative);    
458   }
459   
460   int numRepresentativePEs = nodePERepresentatives.size();
461   
462   int repForMyPe=-1;
463   if(myIndex != -1)
464     repForMyPe = getFirstPeOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
465   
466 #if DEBUG
467   CkPrintf("[%d] Multicasting to %d PEs on %d physical nodes  repForMyPe=%d\n", CkMyPe(), totalDestPEs, numRepresentativePEs, repForMyPe);
468   fflush(stdout);
469 #endif
470   
471   // If this PE is part of the multicast tree, then it should forward the message along
472   if(CkMyPe() == repForMyPe || myIndex == -1){
473     // I am an internal node in the multicast tree
474     
475     // flatten the data structure for nodePERepresentatives
476     int *repPeList = new int[numRepresentativePEs];
477     int myRepIndex = -1;
478     std::set<int>::iterator iter;
479     int p=0;
480     for(iter=nodePERepresentatives.begin(); iter != nodePERepresentatives.end(); iter++){
481       repPeList[p] = *iter;
482       if(*iter == repForMyPe)
483         myRepIndex = p;
484       p++;
485     }
486     CkAssert(myRepIndex >=0 || myIndex==-1);
487       
488     // The logical indices start at 0 = root node. Logical index i corresponds to the entry i+1 in the array of PEs in the message
489     int sendLogicalIndexStart = degree*(myRepIndex+1) + 1;       // inclusive
490     int sendLogicalIndexEnd = sendLogicalIndexStart + degree - 1;   // inclusive
491     
492     if(sendLogicalIndexEnd-1 >= numRepresentativePEs){
493       sendLogicalIndexEnd = numRepresentativePEs;
494     }
495     
496     int numSendTree = sendLogicalIndexEnd - sendLogicalIndexStart + 1;
497     if(numSendTree < 0)
498       numSendTree = 0;
499
500
501     // Send in a ring to the PEs on this node
502     std::vector<int> otherLocalPes = getOtherPesOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
503     int numSendLocal = 0;
504     if(myIndex == -1)
505       numSendLocal = 0;
506     else {
507       if(otherLocalPes.size() > 0)
508         numSendLocal = 1;
509       else
510         numSendLocal = 0;
511     }
512     
513
514 #if DEBUG
515     CkPrintf("[%d] numSendTree=%d numSendLocal=%d sendLogicalIndexStart=%d sendLogicalIndexEnd=%d\n", CkMyPe(), numSendTree, numSendLocal,  sendLogicalIndexStart, sendLogicalIndexEnd);
516     fflush(stdout);
517 #endif
518
519     int numSend = numSendTree + numSendLocal;
520     if(numSend <= 0){
521       npes = 0;
522       return;
523     }
524     
525     npes = numSend;
526     pelist = new int[npes];
527   
528     for(int i=0;i<numSendTree;i++){
529       CkAssert(sendLogicalIndexStart-1+i < numRepresentativePEs);
530       pelist[i] = repPeList[sendLogicalIndexStart-1+i];
531       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
532     }
533     
534     delete[] repPeList;
535     repPeList = NULL;
536
537     for(int i=0;i<numSendLocal;i++){
538       pelist[i+numSendTree] = otherLocalPes[i];
539       CkAssert(pelist[i] < CkNumPes() && pelist[i] >= 0);
540     }
541     
542     
543 #if DEBUG
544     char buf[1024];
545     sprintf(buf, "PE %d is sending to Remote Node PEs: ", CkMyPe() );
546     for(int i=0;i<numSend;i++){
547       if(i==numSendTree)
548         sprintf(buf+strlen(buf), " and Local To Node PEs: ", pelist[i]);
549
550       sprintf(buf+strlen(buf), "%d ", pelist[i]);
551     }    
552     CkPrintf("%s\n", buf);
553     fflush(stdout);
554 #endif
555         
556   } else {
557     // We are a leaf PE, so forward in a ring to the PEs on this node
558     const std::vector<int> otherLocalPes = getOtherPesOnPhysicalNodeFromList(CkMyPe(), totalDestPEs, destPEs);
559     
560     npes = 0;
561     pelist = new int[1];
562     
563     for(int i=0;i<otherLocalPes.size();i++){
564       if(otherLocalPes[i] == CkMyPe()){
565         // found me in the PE list for this node
566         if(i+1<otherLocalPes.size()){
567           // If we have a successor in the ring
568           pelist[0] = otherLocalPes[i+1];
569           npes = 1;
570         }
571       }
572     }
573      
574
575 #if 1
576     if(npes==0)
577       CkPrintf("[%d] At end of ring", CkMyPe() );
578     else
579       CkPrintf("[%d] sending along ring to %d\n", CkMyPe(), pelist[0] );
580     
581     fflush(stdout);
582 #endif
583
584
585   }
586
587   
588   
589 }
590
591 /*@}*/