Fixed a memory leak and a added some other code optimizations.
[charm.git] / src / conv-com / gridrouter.C
1 /*****************************************************************************
2  * $Source$
3  * $Author$
4  * $Date$
5  * $Revision$
6  *****************************************************************************/
7
8 /************************************************************
9  * File : gridrouter.C
10  *
11  * Author : Krishnan V.
12  *
13  * Grid (mesh) based router
14  ***********************************************************/
15
16 #include "gridrouter.h"
17
18 #define gmap(pe) {if (gpes) pe=gpes[pe];}
19
20 /**The only communication op used. Modify this to use
21  ** vector send */
22 #define GRIDSENDFN(kid, u1, u2, knpe, kpelist, khndl, knextpe)  \
23         {int len;\
24         char *newmsg;\
25         newmsg=PeMesh->ExtractAndPack(kid, u1, knpe, kpelist, &len);\
26         if (newmsg) {\
27           CmiSetHandler(newmsg, khndl);\
28           CmiSyncSendAndFree(knextpe, len, newmsg);\
29         }\
30         else {\
31           SendDummyMsg(kid, knextpe, u2);\
32         }\
33 }
34
35 /****************************************************
36  * Preallocated memory=P ints + MAXNUMMSGS msgstructs
37  *****************************************************/
38 GridRouter::GridRouter(int n, int me)
39 {
40   //CmiPrintf("PE=%d me=%d NUMPES=%d\n", MyPe, me, n);
41   
42   NumPes=n;
43   MyPe=me;
44   gpes=NULL;
45   COLLEN=ColLen(NumPes);
46   LPMsgExpected = Expect(MyPe, NumPes);
47   recvExpected = 0;
48
49   myrow=MyPe/COLLEN;
50   mycol=MyPe%COLLEN;  
51   int lastrow = (NumPes - 1)/COLLEN;
52   
53   if(myrow < lastrow) 
54       recvExpected = ROWLEN;
55   else
56       recvExpected = (NumPes - 1)%ROWLEN + 1;
57
58   if(lastrow * COLLEN + mycol > NumPes - 1) {
59       //We have a hole in the lastrow
60       if(lastrow * COLLEN + myrow <= NumPes - 1) 
61           //We have a processor which wants to send data to that hole
62           recvExpected ++;
63       
64       if((myrow == 0) && (NumPes == ROWLEN*(COLLEN-1) - 1))
65           //Special case with one hole only
66           recvExpected ++;      
67   }
68   
69   ComlibPrintf("%d LPMsgExpected=%d\n", MyPe, LPMsgExpected);
70
71   PeMesh = new PeTable(NumPes);
72   PeMesh1 = new PeTable(NumPes);
73   PeMesh2 = new PeTable(NumPes);
74
75   onerow=(int *)CmiAlloc(ROWLEN*sizeof(int));
76
77   rowVector = (int *)CmiAlloc(ROWLEN*sizeof(int));
78   colVector = (int *)CmiAlloc(COLLEN*sizeof(int));
79
80   int myrep=myrow*COLLEN;
81   int count = 0;
82   int pos = 0;
83
84   for(count = myrow; count < ROWLEN+myrow; count ++){
85       int nextpe= myrep + count%ROWLEN;
86       
87       if (nextpe >= NumPes) {
88           int new_row = mycol % (myrow+1);
89           
90           if(new_row >= myrow)
91               new_row = 0;
92           
93           nextpe = COLLEN * new_row + count;
94       }
95       
96       if(nextpe == MyPe)
97           continue;
98
99       rowVector[pos ++] = nextpe;
100   }
101   rvecSize = pos;
102
103   pos = 0;
104   for(count = mycol; count < COLLEN+mycol; count ++){
105       int nextrowrep = (count % COLLEN) *COLLEN;
106       int nextpe = nextrowrep+mycol;
107       
108       if(nextpe < NumPes && nextpe != MyPe)
109           colVector[pos ++] = nextpe;
110   }
111   
112   cvecSize = pos;
113
114   growVector = new int[rvecSize];
115   gcolVector = new int[cvecSize];
116
117   for(count = 0; count < rvecSize; count ++)
118       growVector[count] = rowVector[count];
119   
120   for(count = 0; count < cvecSize; count ++)
121       gcolVector[count] = colVector[count];
122   
123
124   InitVars();
125   ComlibPrintf("%d:COLLEN=%d, ROWLEN=%d, recvexpected=%d\n", MyPe, COLLEN, ROWLEN, recvExpected);
126 }
127
128 GridRouter::~GridRouter()
129 {
130   delete PeMesh;
131   delete PeMesh1;
132   delete PeMesh2;
133     
134   CmiFree(onerow);
135 }
136
137 void GridRouter :: InitVars()
138 {
139   recvCount=0;
140   LPMsgCount=0;
141 }
142
143 void GridRouter::NumDeposits(comID, int num)
144 {
145 }
146
147 void GridRouter::EachToAllMulticast(comID id, int size, void *msg, int more)
148 {
149   int npe=NumPes;
150   int * destpes=(int *)CmiAlloc(sizeof(int)*npe);
151   for (int i=0;i<npe;i++) destpes[i]=i;
152   EachToManyMulticast(id, size, msg, npe, destpes, more);
153 }
154
155 extern void CmiReference(void *blk);
156
157 void GridRouter::EachToManyMulticast(comID id, int size, void *msg, int numpes, int *destpes, int more)
158 {
159   int i=0;
160   
161   if(id.isAllToAll)
162       PeMesh->InsertMsgs(1, &MyPe, size, msg);
163   else
164       PeMesh->InsertMsgs(numpes, destpes, size, msg);
165   
166   if (more) return;
167
168   ComlibPrintf("All messages received %d %d %d\n", MyPe, COLLEN,id.isAllToAll);
169
170   char *a2amsg = NULL;
171   int a2a_len;
172   if(id.isAllToAll) {
173       ComlibPrintf("ALL to ALL flag set\n");
174
175       a2amsg = PeMesh->ExtractAndPackAll(id, 0, &a2a_len);
176       CmiSetHandler(a2amsg, CkpvAccess(RecvHandle));
177       CmiReference(a2amsg);
178       CmiSyncListSendAndFree(rvecSize, growVector, a2a_len, a2amsg);      
179       RecvManyMsg(id, a2amsg);
180       return;
181   }
182
183   //Send the messages
184   //int MYROW  =MyPe/COLLEN;
185   //int MYCOL = MyPe%COLLEN;
186   int myrep= myrow*COLLEN; 
187   int length = (NumPes - 1)/COLLEN + 1;
188  
189   for (int colcount = 0; colcount < rvecSize; ++colcount) {
190       int nextpe = rowVector[colcount];
191       i = nextpe % COLLEN;
192       
193       if((length - 1)* COLLEN + i >= NumPes)
194           length --;
195       
196       for (int j = 0; j < length; j++) {
197           onerow[j]=j * COLLEN + i;
198       }
199       
200       gmap(nextpe);
201
202       ComlibPrintf("%d: before gmap sending to %d of column %d\n",
203                    MyPe, nextpe, i);
204
205       ComlibPrintf("%d:sending to %d of column %d\n", MyPe, nextpe, i);
206       
207       GRIDSENDFN(id, 0, 0, length, onerow, CkpvAccess(RecvHandle), nextpe); 
208   }
209   RecvManyMsg(id, NULL);
210 }
211
212 void GridRouter::RecvManyMsg(comID id, char *msg)
213 {
214   if (msg) {  
215       if(id.isAllToAll)
216           PeMesh1->UnpackAndInsertAll(msg, 1, &MyPe);
217       else
218           PeMesh->UnpackAndInsert(msg);
219   }
220
221   recvCount++;
222   if (recvCount == recvExpected) {
223       ComlibPrintf("%d recvcount=%d recvexpected = %d\n", MyPe, recvCount, recvExpected);
224       
225       char *a2amsg;
226       int a2a_len;
227       if(id.isAllToAll) {
228           a2amsg = PeMesh1->ExtractAndPackAll(id, 1, &a2a_len);
229           CmiSetHandler(a2amsg, CkpvAccess(ProcHandle));
230           CmiReference(a2amsg);
231           CmiSyncListSendAndFree(cvecSize, gcolVector, a2a_len, a2amsg);   
232           ProcManyMsg(id, a2amsg);
233           return;
234       }
235
236       for (int rowcount=0; rowcount < cvecSize; rowcount++) {
237           int nextpe = colVector[rowcount];
238                     
239           int gnextpe = nextpe;
240           int *pelist=&gnextpe;
241           
242           ComlibPrintf("Before gmap %d\n", nextpe);
243           
244           gmap(nextpe);
245
246           ComlibPrintf("After gmap %d\n", nextpe);
247           
248           ComlibPrintf("%d:sending message to %d of row %d\n", MyPe, nextpe, 
249                        rowcount);
250           
251           GRIDSENDFN(id, 0, 1, 1, pelist, CkpvAccess(ProcHandle), nextpe);
252       }
253       
254       LocalProcMsg(id);
255   }
256 }
257
258 void GridRouter::DummyEP(comID id, int magic)
259 {
260   if (magic == 1) {
261         //ComlibPrintf("%d dummy calling lp\n", MyPe);
262         LocalProcMsg(id);
263   }
264   else {
265         //ComlibPrintf("%d dummy calling recv\n", MyPe);
266         RecvManyMsg(id, NULL);
267   }
268 }
269
270 void GridRouter:: ProcManyMsg(comID id, char *m)
271 {
272     if(id.isAllToAll)
273         PeMesh2->UnpackAndInsertAll(m, 1, &MyPe);
274     else
275         PeMesh->UnpackAndInsert(m);
276     //ComlibPrintf("%d proc calling lp\n");
277     
278     LocalProcMsg(id);
279 }
280
281 void GridRouter:: LocalProcMsg(comID id)
282 {
283     LPMsgCount++;
284     PeMesh->ExtractAndDeliverLocalMsgs(MyPe);
285     PeMesh2->ExtractAndDeliverLocalMsgs(MyPe);
286     
287     ComlibPrintf("%d local procmsg called\n", MyPe);
288     if (LPMsgCount==LPMsgExpected) {
289         PeMesh->Purge();
290         PeMesh2->Purge();
291         
292         InitVars();
293         Done(id);
294     }
295 }
296
297 Router * newgridobject(int n, int me)
298 {
299   Router *obj=new GridRouter(n, me);
300   return(obj);
301 }
302
303 void GridRouter :: SetMap(int *pes)
304 {
305     
306   gpes=pes;
307   
308   if(!gpes)
309       return;
310
311   int count = 0;
312
313   for(count = 0; count < rvecSize; count ++)
314       growVector[count] = gpes[rowVector[count]];
315   
316   for(count = 0; count < cvecSize; count ++)
317       gcolVector[count] = gpes[colVector[count]];
318 }
319