Adding fix for AMPI broadcast strategy.
[charm.git] / src / ck-com / RingMulticastStrategy.C
1 #include "RingMulticastStrategy.h"
2
3 //Group Constructor
4 RingMulticastStrategy::RingMulticastStrategy(int ndest, int *pelist) 
5     : DirectMulticastStrategy(ndest, pelist) {
6     commonRingInit();
7 }
8
9 //Array Constructor
10 RingMulticastStrategy::RingMulticastStrategy(CkArrayID dest_aid)
11     : DirectMulticastStrategy(dest_aid){
12     commonRingInit();    
13 }
14
15 //Array Constructor
16 RingMulticastStrategy::RingMulticastStrategy(CkArrayID src, CkArrayID dest)
17     : DirectMulticastStrategy(src, dest){
18     commonRingInit();    
19 }
20
21 void RingMulticastStrategy::commonRingInit(){
22     //Sort destpelist
23 }
24
25
26 void RingMulticastStrategy::insertMessage(CharmMessageHolder *cmsg){
27     if(messageBuf == NULL) {
28         CkPrintf("ERROR MESSAGE BUF IS NULL\n");
29         return;
30     }
31     
32     ComlibPrintf("[%d] Comlib Direct Multicast: insertMessage \n", 
33                  CkMyPe());   
34     
35     if(cmsg->dest_proc == IS_BROADCAST) {
36         void *m = cmsg->getCharmMessage();
37         CkSectionInfo minfo;
38         minfo.type = COMLIB_MULTICAST_MESSAGE;
39         minfo.sInfo.cInfo.instId = getInstance();
40         minfo.sInfo.cInfo.status = COMLIB_MULTICAST_ALL;  
41         minfo.sInfo.cInfo.id = 0; 
42         minfo.pe = CkMyPe();
43         ((CkMcastBaseMsg *)m)->_cookie = minfo;       
44     }
45
46     if(cmsg->dest_proc == IS_SECTION_MULTICAST && cmsg->sec_id != NULL) { 
47         int cur_sec_id = ComlibSectionInfo::getSectionID(*cmsg->sec_id);
48
49         if(cur_sec_id > 0) {        
50             sinfo.processOldSectionMessage(cmsg);
51         }
52         else {
53             CkSectionID *sid = cmsg->sec_id;
54
55             //New sec id, so send it along with the message
56             void *newmsg = sinfo.getNewMulticastMessage(cmsg);
57             CkFreeMsg(cmsg->getCharmMessage());
58             delete cmsg;
59             
60             initSectionID(sid);
61             cmsg = new CharmMessageHolder((char *)newmsg, 
62                                           IS_SECTION_MULTICAST); 
63             cmsg->sec_id = sid;
64         }        
65     }
66     
67     messageBuf->enq(cmsg);
68     if(!isBracketed())
69         doneInserting();
70 }
71
72 extern int _charmHandlerIdx;
73 void RingMulticastStrategy::doneInserting(){
74     ComlibPrintf("%d: DoneInserting \n", CkMyPe());
75     
76     if(messageBuf->length() == 0) {
77         return;
78     }
79
80     while(!messageBuf->isEmpty()) {
81         CharmMessageHolder *cmsg = messageBuf->deq();
82         char *msg = cmsg->getCharmMessage();
83         register envelope* env = UsrToEnv(msg);
84
85         ComlibPrintf("[%d] Calling Ring %d %d %d\n", CkMyPe(),
86                      env->getTotalsize(), ndestpes, cmsg->dest_proc);
87                 
88         if(cmsg->dest_proc == IS_SECTION_MULTICAST ||
89            cmsg->dest_proc == IS_BROADCAST) {      
90             
91             CmiSetHandler(env, handlerId);
92             
93             int dest_pe = -1;
94             RingMulticastHashObject *robj;
95             
96             if(cmsg->sec_id == NULL)
97                 dest_pe = nextPE;
98             else {
99                 robj = getHashObject(CkMyPe(), 
100                                      cmsg->sec_id->_cookie.sInfo.cInfo.id);
101                 
102                 ComlibPrintf("Gotten has obect %d\n",  robj);                
103                 CkAssert(robj != NULL);                
104                 dest_pe = robj->nextPE;
105             }
106             
107             ComlibPrintf("[%d] Sending Message to %d\n", CkMyPe(), dest_pe);
108
109             if(dest_pe != -1)
110                 CmiSyncSend(dest_pe, env->getTotalsize(), (char *)env); 
111             
112             if(getType() == ARRAY_STRATEGY) {
113                 CmiSyncSendAndFree(CkMyPe(), env->getTotalsize(), (char *)env);
114             }
115             else {
116                 CmiSetHandler(env, _charmHandlerIdx);
117                 CmiSyncSendAndFree(CkMyPe(), env->getTotalsize(), (char *)env);
118             }
119         }
120         else {
121             CmiSyncSendAndFree(cmsg->dest_proc, UsrToEnv(msg)->getTotalsize(), 
122                                (char *)UsrToEnv(msg));
123         }        
124         
125         delete cmsg; 
126     }
127 }
128
129 void RingMulticastStrategy::pup(PUP::er &p){
130
131     DirectMulticastStrategy::pup(p);
132 }
133
134 void RingMulticastStrategy::beginProcessing(int  nelements){
135
136     DirectMulticastStrategy::beginProcessing(nelements);
137
138     nextPE = -1;
139     if(ndestpes == 1)
140         return;
141
142     for(int count = 0; count < ndestpes; count++)
143         if(destpelist[count] > CkMyPe()) {
144             nextPE = destpelist[count];
145             break;
146         }
147     if(nextPE == -1)
148         nextPE = destpelist[0];
149 }
150
151 void RingMulticastStrategy::handleMulticastMessage(void *msg){
152     register envelope *env = (envelope *)msg;
153        
154     CkMcastBaseMsg *cbmsg = (CkMcastBaseMsg *)EnvToUsr(env);
155     int src_pe = cbmsg->_cookie.pe;
156     if(getType() == GROUP_STRATEGY){               
157
158         if(!isEndOfRing(nextPE, src_pe)) {
159             ComlibPrintf("[%d] Forwarding Message to %d\n", CkMyPe(), nextPE);
160             CmiSyncSend(nextPE, env->getTotalsize(), (char *)env);        
161         }
162         CmiSetHandler(env, _charmHandlerIdx);
163         CmiSyncSendAndFree(CkMyPe(), env->getTotalsize(), (char *)env);
164         
165         return;
166     }
167
168     int status = cbmsg->_cookie.sInfo.cInfo.status;
169     ComlibPrintf("[%d] In handle multicast message %d\n", CkMyPe(), status);
170
171     if(status == COMLIB_MULTICAST_ALL) {                        
172         if(src_pe != CkMyPe() && !isEndOfRing(nextPE, src_pe)) {
173             ComlibPrintf("[%d] Forwarding Message to %d\n", CkMyPe(), nextPE);
174             CmiSyncSend(nextPE, env->getTotalsize(), (char *)env); 
175         }
176
177         ainfo.localBroadcast(env);
178     }   
179     else if(status == COMLIB_MULTICAST_NEW_SECTION){        
180         CkUnpackMessage(&env);
181         ComlibPrintf("[%d] Received message for new section src=%d\n", 
182                      CkMyPe(), cbmsg->_cookie.pe);
183
184         ComlibMulticastMsg *ccmsg = (ComlibMulticastMsg *)cbmsg;
185         
186         RingMulticastHashObject *robj = 
187             createHashObject(ccmsg->nIndices, ccmsg->indices);
188         
189         envelope *usrenv = (envelope *) ccmsg->usrMsg;
190         
191         envelope *newenv = (envelope *)CmiAlloc(usrenv->getTotalsize());
192         memcpy(newenv, usrenv, usrenv->getTotalsize());
193
194         ComlibArrayInfo::localMulticast(&robj->indices, newenv);
195
196         ComlibSectionHashKey key(cbmsg->_cookie.pe, 
197                                  cbmsg->_cookie.sInfo.cInfo.id);
198
199         RingMulticastHashObject *old_robj = 
200             (RingMulticastHashObject*)sec_ht.get(key);
201         if(old_robj != NULL)
202             delete old_robj;
203         
204         sec_ht.put(key) = robj;
205
206         if(src_pe != CkMyPe() && !isEndOfRing(robj->nextPE, src_pe)) {
207             ComlibPrintf("[%d] Forwarding Message of %d to %d\n", CkMyPe(), 
208                          cbmsg->_cookie.pe, robj->nextPE);
209             CkPackMessage(&env);
210             CmiSyncSendAndFree(robj->nextPE, env->getTotalsize(), 
211                                (char *)env);
212         }
213         else
214             CmiFree(env);       
215     }
216     else {
217         //status == COMLIB_MULTICAST_OLD_SECTION, use the cached section id
218         ComlibSectionHashKey key(cbmsg->_cookie.pe, 
219                                  cbmsg->_cookie.sInfo.cInfo.id);    
220         RingMulticastHashObject *robj = (RingMulticastHashObject *)sec_ht.
221             get(key);
222         
223         if(robj == NULL)
224             CkAbort("Destination indices is NULL\n");
225         
226         if(src_pe != CkMyPe() && !isEndOfRing(robj->nextPE, src_pe)) {
227             CmiSyncSend(robj->nextPE, env->getTotalsize(), (char *)env);
228             ComlibPrintf("[%d] Forwarding Message to %d\n", CkMyPe(), 
229                          robj->nextPE);
230         }
231         
232         ComlibArrayInfo::localMulticast(&robj->indices, env);
233     }
234 }
235
236 void RingMulticastStrategy::initSectionID(CkSectionID *sid){
237
238     ComlibPrintf("Ring Init section ID\n");
239     sid->pelist = NULL;
240     sid->npes = 0;
241
242     RingMulticastHashObject *robj = 
243         createHashObject(sid->_nElems, sid->_elems);
244     
245     ComlibSectionHashKey key(CkMyPe(), sid->_cookie.sInfo.cInfo.id);
246     sec_ht.put(key) = robj;
247 }
248
249 RingMulticastHashObject *RingMulticastStrategy::createHashObject
250 (int nelements, CkArrayIndexMax *elements){
251
252     RingMulticastHashObject *robj = new RingMulticastHashObject;
253
254     int next_pe = CkNumPes();
255     int acount = 0;
256     int min_dest = CkNumPes();
257     for(acount = 0; acount < nelements; acount++){
258         //elements[acount].print();
259         
260         CkArrayID dest;
261         int nidx;
262         CkArrayIndexMax *idx_list;        
263         ainfo.getDestinationArray(dest, idx_list, nidx);
264
265         int p = ComlibGetLastKnown(dest, elements[acount]);
266         //CkArrayID::CkLocalBranch(dest)->lastKnown(elements[acount]);
267         
268         if(p < min_dest)
269             min_dest = p;
270         
271         if(p > CkMyPe() && next_pe > p) 
272             next_pe = p;       
273
274         if (p == CkMyPe())
275             robj->indices.insertAtEnd(elements[acount]);
276     }
277     
278     //Recycle the destination pelist and start from the begining
279     if(next_pe == CkNumPes() && min_dest != CkMyPe())        
280         next_pe = min_dest;
281     
282     if(next_pe == CkNumPes())
283         next_pe = -1;
284
285     robj->nextPE = next_pe;
286
287     return robj;
288 }
289
290
291 RingMulticastHashObject *RingMulticastStrategy::getHashObject(int pe, int id){
292     
293     ComlibSectionHashKey key(pe, id);
294     RingMulticastHashObject *robj = (RingMulticastHashObject *)sec_ht.get(key);
295     return robj;
296 }
297
298 int RingMulticastStrategy::isEndOfRing(int next_pe, int src_pe){
299
300     if(next_pe < 0)
301         return 1;
302
303     ComlibPrintf("[%d] isEndofring %d, %d\n", CkMyPe(), next_pe, src_pe);
304
305     if(next_pe > CkMyPe()){
306         if(src_pe <= next_pe && src_pe > CkMyPe())
307             return 1;
308
309         return 0;
310     }
311     
312     //next_pe < CkMyPe()
313
314     if(src_pe > CkMyPe() || src_pe <= next_pe)
315         return 1;
316     
317     return 0;
318 }