ck-ldb: the API for the virtual work function changed
[charm.git] / src / ck-ldb / MetisLB.C
1 /*****************************************************************************
2  * $Source$
3  * $Author$
4  * $Date$
5  * $Revision$
6  *****************************************************************************/
7
8 /**
9  * \addtogroup CkLdb
10 */
11 /*@{*/
12
13 #include <charm++.h>
14
15 #include "cklists.h"
16
17 #include "MetisLB.h"
18
19 CreateLBFunc_Def(MetisLB, "Use Metis(tm) to partition object graph")
20
21 MetisLB::MetisLB(const CkLBOptions &opt): CentralLB(opt)
22 {
23   lbname = "MetisLB";
24   if (CkMyPe() == 0)
25     CkPrintf("[%d] MetisLB created\n",CkMyPe());
26 }
27
28 static void printStats(int count, int numobjs, double *cputimes, 
29                        int **comm, int *map)
30 {
31   int i, j;
32   double *petimes = new double[count];
33   for(i=0;i<count;i++) {
34     petimes[i] = 0.0;
35   }
36   for(i=0;i<numobjs;i++) {
37     petimes[map[i]] += cputimes[i];
38   }
39   double maxpe = petimes[0], minpe = petimes[0];
40   CkPrintf("\tPE\tTimexSpeed\n");
41   for(i=0;i<count;i++) {
42     CkPrintf("\t%d\t%lf\n",i,petimes[i]);
43     if(maxpe < petimes[i])
44       maxpe = petimes[i];
45     if(minpe > petimes[i])
46       minpe = petimes[i];
47   }
48   delete[] petimes;
49   CkPrintf("\tLoad Imbalance=%lf seconds\n", maxpe-minpe);
50   int ncomm = 0;
51   for(i=0;i<numobjs;i++) {
52     for(j=0;j<numobjs;j++) {
53       if(map[i] != map[j])
54         ncomm += comm[i][j];
55     }
56   }
57   CkPrintf("\tCommunication (off proc msgs) = %d\n", ncomm/2);
58 }
59
60 extern "C" void METIS_PartGraphRecursive(int*, int*, int*, int*, int*,
61                                          int*, int*, int*, int*,
62                                          int*, int*);
63 extern "C" void METIS_PartGraphKway(int*, int*, int*, int*, int*,
64                                     int*, int*, int*, int*,
65                                     int*, int*);
66 extern "C" void METIS_PartGraphVKway(int*, int*, int*, int*, int*,
67                                      int*, int*, int*, int*,
68                                      int*, int*);
69
70 // the following are to compute a partitioning with a given partition weights
71 // "W" means giving weights
72 extern "C" void METIS_WPartGraphRecursive(int*, int*, int*, int*, int*,
73                                           int*, int*, int*, float*, int*,
74                                           int*, int*);
75 extern "C" void METIS_WPartGraphKway(int*, int*, int*, int*, int*,
76                                      int*, int*, int*, float*, int*,
77                                      int*, int*);
78
79 // the following are for multiple constraint partition "mC"
80 extern "C" void METIS_mCPartGraphRecursive(int*, int*, int*, int*, int*, int*,
81                                          int*, int*, int*, int*,
82                                          int*, int*);
83 extern "C" void METIS_mCPartGraphKway(int*, int*, int*, int*, int*, int*,
84                                     int*, int*, int*, int*, int*,
85                                     int*, int*);
86
87 void MetisLB::work(LDStats* stats)
88 {
89   if (_lb_args.debug() >= 2) {
90     CkPrintf("[%d] In MetisLB Strategy...\n", CkMyPe());
91   }
92   int i, j, m;
93   int option = 0;
94
95   stats->makeCommHash();
96
97   int n_pes = stats->count;
98   int numobjs = stats->n_objs;
99
100   removeNonMigratable(stats, n_pes);
101
102   // allocate space for the computing data
103   double *objtime = new double[numobjs];
104   int *objwt = new int[numobjs];
105   int *origmap = new int[numobjs];
106   LDObjHandle *handles = new LDObjHandle[numobjs];
107   for(i=0;i<numobjs;i++) {
108     objtime[i] = 0.0;
109     objwt[i] = 0;
110     origmap[i] = 0;
111   }
112
113   for (i=0; i<stats->n_objs; i++) {
114       LDObjData &odata = stats->objData[i];
115       if (!odata.migratable) 
116         CmiAbort("MetisLB doesnot dupport nonmigratable object.\n");
117       /*
118       origmap[odata[i].id.id[0]] = j;
119       cputime[odata[i].id.id[0]] = odata[i].cpuTime;
120       handles[odata[i].id.id[0]] = odata[i].handle;
121       */
122       int frompe = stats->from_proc[i];
123       origmap[i] = frompe;
124       objtime[i] = odata.wallTime*stats->procs[frompe].pe_speed;
125       handles[i] = odata.handle;
126   }
127
128   // to convert the weights on vertices to integers
129   double max_objtime = objtime[0];
130   for(i=0; i<numobjs; i++) {
131     if(max_objtime < objtime[i])
132       max_objtime = objtime[i];
133   }
134   double ratio = 1000.0/max_objtime;
135   for(i=0; i<numobjs; i++) {
136       objwt[i] = (int)(objtime[i]*ratio);
137   }
138   int **comm = new int*[numobjs];
139   for (i=0; i<numobjs; i++) {
140     comm[i] = new int[numobjs];
141     for (j=0; j<numobjs; j++)  {
142       comm[i][j] = 0;
143     }
144   }
145
146   const int csz = stats->n_comm;
147   for(i=0; i<csz; i++) {
148       LDCommData &cdata = stats->commData[i];
149       if(!cdata.from_proc() && cdata.receiver.get_type() == LD_OBJ_MSG)
150       {
151         int senderID = stats->getHash(cdata.sender);
152         int recverID = stats->getHash(cdata.receiver.get_destObj());
153         if (stats->complete_flag == 0 && recverID == -1) continue;
154         CmiAssert(senderID < numobjs && senderID >= 0);
155         CmiAssert(recverID < numobjs && recverID >= 0);
156         comm[senderID][recverID] += cdata.messages;
157         comm[recverID][senderID] += cdata.messages;
158       }
159       else if (cdata.receiver.get_type() == LD_OBJLIST_MSG) {
160         int nobjs;
161         LDObjKey *objs = cdata.receiver.get_destObjs(nobjs);
162         int senderID = stats->getHash(cdata.sender);
163         for (j=0; j<nobjs; j++) {
164            int recverID = stats->getHash(objs[j]);
165            if((senderID == -1)||(recverID == -1))
166               if (_lb_args.migObjOnly()) continue;
167               else CkAbort("Error in search\n");
168            comm[senderID][recverID] += cdata.messages;
169            comm[recverID][senderID] += cdata.messages;
170         }
171       }
172     }
173
174 // ignore messages sent from an object to itself
175   for (i=0; i<numobjs; i++)
176     comm[i][i] = 0;
177
178   // construct the graph in CSR format
179   int *xadj = new int[numobjs+1];
180   int numedges = 0;
181   for(i=0;i<numobjs;i++) {
182     for(j=0;j<numobjs;j++) {
183       if(comm[i][j] != 0)
184         numedges++;
185     }
186   }
187   int *adjncy = new int[numedges];
188   int *edgewt = new int[numedges];
189   xadj[0] = 0;
190   int count4all = 0;
191   for (i=0; i<numobjs; i++) {
192     for (j=0; j<numobjs; j++) { 
193       if (comm[i][j] != 0) { 
194         adjncy[count4all] = j;
195         edgewt[count4all++] = comm[i][j];
196       }
197     }
198     xadj[i+1] = count4all;
199   }
200
201   if (_lb_args.debug() >= 2) {
202   CkPrintf("Pre-LDB Statistics step %d\n", step());
203   printStats(n_pes, numobjs, objtime, comm, origmap);
204   }
205
206   int wgtflag = 3; // Weights both on vertices and edges
207   int numflag = 0; // C Style numbering
208   int options[5];
209   options[0] = 0;
210   int edgecut;
211   int *newmap;
212   int sameMapFlag = 1;
213
214   if (n_pes < 1) {
215     CkPrintf("error: Number of Pe less than 1!");
216   }
217   else if (n_pes == 1) {
218     newmap = origmap;
219     sameMapFlag = 1;
220   }
221   else {
222     sameMapFlag = 0;
223     newmap = new int[numobjs];
224     //for(i=0;i<(numobjs+1);i++)
225       //xadj[i] = 0;
226     //delete[] edgewt;
227     //edgewt = 0;
228     //wgtflag = 2;
229     // CkPrintf("before calling Metis functions. option is %d.\n", option);
230     if (0 == option) {
231
232 /*  I intended to follow the instruction in the Metis 4.0 manual
233     which said that METIS_PartGraphKway is preferable to 
234     METIS_PartGraphRecursive, when nparts > 8.
235     However, it turned out that there is bug in METIS_PartGraphKway,
236     and the function seg faulted when nparts = 4 or 9.
237     So right now I just comment that function out and always use the other one.
238 */
239 /*
240       if (n_pes > 8)
241         METIS_PartGraphKway(&numobjs, xadj, adjncy, objwt, edgewt, 
242                             &wgtflag, &numflag, &n_pes, options,
243                             &edgecut, newmap);
244       else
245         METIS_PartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt, 
246                                  &wgtflag, &numflag, &n_pes, options,
247                                  &edgecut, newmap);
248 */
249       if (_lb_args.debug() >= 1)
250         CkPrintf("[%d] calling METIS_PartGraphRecursive.\n", CkMyPe());
251       METIS_PartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt,
252                                  &wgtflag, &numflag, &n_pes, options,
253                                  &edgecut, newmap);
254       if (_lb_args.debug() >= 1)
255         CkPrintf("[%d] after calling Metis functions.\n", CkMyPe());
256     }
257     else if (WEIGHTED == option) {
258       CkPrintf("unepected\n");
259       float maxtotal_walltime = stats->procs[0].total_walltime;
260       for (m = 1; m < n_pes; m++) {
261         if (maxtotal_walltime < stats->procs[m].total_walltime)
262           maxtotal_walltime = stats->procs[m].total_walltime;
263       }
264       float totaltimeAllPe = 0.0;
265       for (m = 0; m < n_pes; m++) {
266         totaltimeAllPe += stats->procs[m].pe_speed * 
267           (maxtotal_walltime-stats->procs[m].bg_walltime);
268       }
269       // set up the different weights
270       float *tpwgts = new float[n_pes];
271       for (m = 0; m < n_pes; m++) {
272         tpwgts[m] = stats->procs[m].pe_speed * 
273           (maxtotal_walltime-stats->procs[m].bg_walltime) / totaltimeAllPe;
274       }
275       if (n_pes > 8)
276         METIS_WPartGraphKway(&numobjs, xadj, adjncy, objwt, edgewt, 
277                              &wgtflag, &numflag, &n_pes, tpwgts, options,
278                              &edgecut, newmap);
279       else
280         METIS_WPartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt, 
281                                   &wgtflag, &numflag, &n_pes, tpwgts, options,
282                                   &edgecut, newmap);
283       delete[] tpwgts;
284     }
285     else if (MULTI_CONSTRAINT == option) {
286       CkPrintf("Metis load balance strategy: ");
287       CkPrintf("multiple constraints not implemented yet.\n");
288     }
289   }
290   if (_lb_args.debug() >= 2) {
291   CkPrintf("Post-LDB Statistics step %d\n", step());
292   printStats(n_pes, numobjs, objtime, comm, newmap);
293   }
294
295   for(i=0;i<numobjs;i++)
296     delete[] comm[i];
297   delete[] comm;
298   delete[] objtime;
299   delete[] xadj;
300   delete[] adjncy;
301   if(objwt) delete[] objwt;
302   if(edgewt) delete[] edgewt;
303         
304   /*CkPrintf("obj-proc mapping\n");
305         for(i=0;i<numobjs;i++)
306                 CkPrintf(" %d,%d ",i,newmap[i]);
307   */
308   if(!sameMapFlag) {
309     for(i=0; i<numobjs; i++) {
310       if(origmap[i] != newmap[i]) {
311         CmiAssert(stats->from_proc[i] == origmap[i]);
312         stats->to_proc[i] =  newmap[i];
313         if (_lb_args.debug() >= 3)
314             CkPrintf("[%d] Obj %d migrating from %d to %d\n", CkMyPe(),i,stats->from_proc[i],stats->to_proc[i]);
315       }
316     }
317   }
318         
319         //CkPrintf("chking wts on each partition...\n");
320
321 /*
322         int avg=0;
323         int *chkwt = new int[n_pes];
324         for(i=0; i<n_pes; i++)
325                 chkwt[i]=0;
326         //totalwt=0;
327         for(i=0;i<numobjs;i++){
328                 chkwt[newmap[i]] += objwt[i];
329                 avg += objwt[i];
330                 
331         }
332         
333         
334         for(i=0; i<n_pes; i++)
335                 CkPrintf("%d -- %d\n",i,chkwt[i]);
336 */
337   delete[] origmap;
338   if(newmap != origmap)
339     delete[] newmap;
340   if (_lb_args.debug() >= 1) {
341    CkPrintf("[%d] MetisLB done! \n", CkMyPe());
342   }
343 }
344
345 #include "MetisLB.def.h"
346
347 /*@}*/