Adding fix for AMPI broadcast strategy.
[charm.git] / src / ck-com / DirectMulticastStrategy.C
1 #include "DirectMulticastStrategy.h"
2 #include "AAMLearner.h"
3
4 CkpvExtern(CkGroupID, cmgrID);
5
6 void *DMHandler(void *msg){
7     ComlibPrintf("[%d]:In CallbackHandler\n", CkMyPe());
8     DirectMulticastStrategy *nm_mgr;    
9     
10     CkMcastBaseMsg *bmsg = (CkMcastBaseMsg *)EnvToUsr((envelope *)msg);
11     int instid = bmsg->_cookie.sInfo.cInfo.instId;
12     
13     nm_mgr = (DirectMulticastStrategy *) 
14         CProxy_ComlibManager(CkpvAccess(cmgrID)).
15         ckLocalBranch()->getStrategy(instid);
16
17     envelope *env = (envelope *) msg;
18     RECORD_RECV_STATS(instid, env->getTotalsize(), env->getSrcPe());
19     nm_mgr->handleMulticastMessage(msg);
20     return NULL;
21 }
22
23 //Group Constructor
24 DirectMulticastStrategy::DirectMulticastStrategy(int ndest, int *pelist)
25     : CharmStrategy() {
26  
27     setType(GROUP_STRATEGY);
28     
29     ndestpes = ndest;
30     destpelist = pelist;
31
32     commonInit();
33 }
34
35 DirectMulticastStrategy::DirectMulticastStrategy(CkArrayID aid)
36     :  CharmStrategy() {
37
38     //ainfo.setSourceArray(aid);
39     ainfo.setDestinationArray(aid);
40     setType(ARRAY_STRATEGY);
41     ndestpes = 0;
42     destpelist = 0;
43     commonInit();
44 }
45
46 DirectMulticastStrategy::DirectMulticastStrategy(CkArrayID said, CkArrayID daid)
47     :  CharmStrategy() {
48
49     ainfo.setSourceArray(said);
50     ainfo.setDestinationArray(daid);
51     setType(ARRAY_STRATEGY);
52     ndestpes = 0;
53     destpelist = 0;
54     commonInit();
55 }
56
57 void DirectMulticastStrategy::commonInit(){
58
59     if(ndestpes == 0) {
60         ndestpes = CkNumPes();
61         destpelist = new int[CkNumPes()];
62         for(int count = 0; count < CkNumPes(); count ++)
63             destpelist[count] = count;        
64     }
65 }
66
67 DirectMulticastStrategy::~DirectMulticastStrategy() {
68     if(ndestpes > 0)
69         delete [] destpelist;
70
71     if(getLearner() != NULL)
72         delete getLearner();
73         
74     CkHashtableIterator *ht_iterator = sec_ht.iterator();
75     ht_iterator->seekStart();
76     while(ht_iterator->hasNext()){
77         void **data;
78         data = (void **)ht_iterator->next();        
79         CkVec<CkArrayIndexMax> *a_vec = (CkVec<CkArrayIndexMax> *) (* data);
80         if(a_vec != NULL)
81             delete a_vec;
82     }
83 }
84
85 void DirectMulticastStrategy::insertMessage(CharmMessageHolder *cmsg){
86     if(messageBuf == NULL) {
87         CkPrintf("ERROR MESSAGE BUF IS NULL\n");
88         return;
89     }
90
91     ComlibPrintf("[%d] Comlib Direct Multicast: insertMessage \n", 
92                  CkMyPe());   
93    
94     if(cmsg->dest_proc == IS_BROADCAST) {
95         void *m = cmsg->getCharmMessage();
96         CkSectionInfo minfo;
97         minfo.type = COMLIB_MULTICAST_MESSAGE;
98         minfo.sInfo.cInfo.instId = getInstance();
99         minfo.sInfo.cInfo.status = COMLIB_MULTICAST_ALL;  
100         minfo.sInfo.cInfo.id = 0; 
101         minfo.pe = CkMyPe();
102         ((CkMcastBaseMsg *)m)->_cookie = minfo;       
103     }
104
105     if(cmsg->dest_proc == IS_SECTION_MULTICAST && cmsg->sec_id != NULL) { 
106         int cur_sec_id = ComlibSectionInfo::getSectionID(*cmsg->sec_id);
107
108         if(cur_sec_id > 0) {        
109             sinfo.processOldSectionMessage(cmsg);
110         }
111         else {
112             CkSectionID *sid = cmsg->sec_id;
113
114             //New sec id, so send it along with the message
115             void *newmsg = sinfo.getNewMulticastMessage(cmsg);
116             CkFreeMsg(cmsg->getCharmMessage());
117             delete cmsg;
118             
119             sinfo.initSectionID(sid);
120
121             cmsg = new CharmMessageHolder((char *)newmsg, 
122                                           IS_SECTION_MULTICAST); 
123             cmsg->sec_id = sid;
124         }        
125     }
126    
127     messageBuf->enq(cmsg);
128     if(!isBracketed())
129         doneInserting();
130 }
131
132 void DirectMulticastStrategy::doneInserting(){
133     ComlibPrintf("%d: DoneInserting \n", CkMyPe());
134     
135     if(messageBuf->length() == 0) {
136         return;
137     }
138
139     while(!messageBuf->isEmpty()) {
140         CharmMessageHolder *cmsg = messageBuf->deq();
141         char *msg = cmsg->getCharmMessage();
142                 
143         if(cmsg->dest_proc == IS_SECTION_MULTICAST || 
144            cmsg->dest_proc == IS_BROADCAST) {      
145
146             if(getType() == ARRAY_STRATEGY)
147                 CmiSetHandler(UsrToEnv(msg), handlerId);
148             
149             int *cur_map = destpelist;
150             int cur_npes = ndestpes;
151             if(cmsg->sec_id != NULL && cmsg->sec_id->pelist != NULL) {
152                 cur_map = cmsg->sec_id->pelist;
153                 cur_npes = cmsg->sec_id->npes;
154             }
155             
156             //Collect Multicast Statistics
157             RECORD_SENDM_STATS(getInstance(), 
158                                ((envelope *)cmsg->getMessage())->getTotalsize(), 
159                                cur_map, cur_npes);
160
161
162             ComlibPrintf("[%d] Calling Direct Multicast %d %d %d\n", CkMyPe(),
163                          UsrToEnv(msg)->getTotalsize(), cur_npes, 
164                          cmsg->dest_proc);
165
166             /*
167               for(int i=0; i < cur_npes; i++)
168               CkPrintf("[%d] Sending to %d %d\n", CkMyPe(), 
169               cur_map[i], cur_npes);
170             */
171
172             CmiSyncListSendAndFree(cur_npes, cur_map, 
173                                    UsrToEnv(msg)->getTotalsize(), 
174                                    (char*)(UsrToEnv(msg)));            
175         }
176         else {
177             //CkPrintf("SHOULD NOT BE HERE\n");
178             CmiSyncSendAndFree(cmsg->dest_proc, 
179                                UsrToEnv(msg)->getTotalsize(), 
180                                (char *)UsrToEnv(msg));
181         }        
182         
183         delete cmsg; 
184     }
185 }
186
187 void DirectMulticastStrategy::pup(PUP::er &p){
188
189     CharmStrategy::pup(p);
190
191     p | ndestpes;
192     if(p.isUnpacking() && ndestpes > 0)
193         destpelist = new int[ndestpes];
194     
195     p(destpelist, ndestpes);        
196     
197     if(p.isUnpacking()) {
198         CkArrayID src;
199         int nidx;
200         CkArrayIndexMax *idx_list;     
201         ainfo.getSourceArray(src, idx_list, nidx);
202         
203         if(!src.isZero()) {
204             AAMLearner *l = new AAMLearner();
205             setLearner(l);
206         }
207     }
208 }
209
210 void DirectMulticastStrategy::beginProcessing(int numElements){
211     
212     messageBuf = new CkQ<CharmMessageHolder *>;    
213     handlerId = CkRegisterHandler((CmiHandler)DMHandler);    
214     
215     CkArrayID dest;
216     int nidx;
217     CkArrayIndexMax *idx_list;
218
219     ainfo.getDestinationArray(dest, idx_list, nidx);
220     sinfo = ComlibSectionInfo(dest, myInstanceID);
221 }
222
223 void DirectMulticastStrategy::handleMulticastMessage(void *msg){
224     register envelope *env = (envelope *)msg;
225     
226     CkMcastBaseMsg *cbmsg = (CkMcastBaseMsg *)EnvToUsr(env);
227
228     int status = cbmsg->_cookie.sInfo.cInfo.status;
229     ComlibPrintf("[%d] In local multicast %d\n", CkMyPe(), status);
230     
231     CkVec<CkArrayIndexMax> *dest_indices; 
232     if(status == COMLIB_MULTICAST_ALL) {        
233         ainfo.localBroadcast(env);
234     }   
235     else if(status == COMLIB_MULTICAST_NEW_SECTION){        
236         CkUnpackMessage(&env);
237         envelope *newenv;
238         sinfo.unpack(env, dest_indices, newenv);
239         ComlibArrayInfo::localMulticast(dest_indices, newenv);
240
241         CkVec<CkArrayIndexMax> *old_dest_indices;
242         ComlibSectionHashKey key(cbmsg->_cookie.pe, 
243                                  cbmsg->_cookie.sInfo.cInfo.id);
244
245         old_dest_indices = (CkVec<CkArrayIndexMax> *)sec_ht.get(key);
246         if(old_dest_indices != NULL)
247             delete old_dest_indices;
248         
249         sec_ht.put(key) = dest_indices;
250         CmiFree(env);                
251     }
252     else {
253         //status == COMLIB_MULTICAST_OLD_SECTION, use the cached section id
254         ComlibSectionHashKey key(cbmsg->_cookie.pe, 
255                                  cbmsg->_cookie.sInfo.cInfo.id);    
256         dest_indices = (CkVec<CkArrayIndexMax> *)sec_ht.get(key);
257         
258         if(dest_indices == NULL)
259             CkAbort("Destination indices is NULL\n");
260         
261         ComlibArrayInfo::localMulticast(dest_indices, env);
262     }
263 }