da49b1a866d073bc1c362a5883b520abbc8878cb
[charm.git] / src / ck-ldb / MetisLB.C
1 /*****************************************************************************
2  * $Source$
3  * $Author$
4  * $Date$
5  * $Revision$
6  *****************************************************************************/
7
8 /**
9  * \addtogroup CkLdb
10 */
11 /*@{*/
12
13 #include <charm++.h>
14
15 #include "cklists.h"
16
17 #include "MetisLB.h"
18
19 CreateLBFunc_Def(MetisLB, "Use Metis(tm) to partition object graph")
20
21 MetisLB::MetisLB(const CkLBOptions &opt): CentralLB(opt)
22 {
23   lbname = "MetisLB";
24   if (CkMyPe() == 0)
25     CkPrintf("[%d] MetisLB created\n",CkMyPe());
26 }
27
28 static void printStats(int count, int numobjs, double *cputimes, 
29                        int **comm, int *map)
30 {
31   int i, j;
32   double *petimes = new double[count];
33   for(i=0;i<count;i++) {
34     petimes[i] = 0.0;
35   }
36   for(i=0;i<numobjs;i++) {
37     petimes[map[i]] += cputimes[i];
38   }
39   double maxpe = petimes[0], minpe = petimes[0];
40   CkPrintf("\tPE\tTimexSpeed\n");
41   for(i=0;i<count;i++) {
42     CkPrintf("\t%d\t%lf\n",i,petimes[i]);
43     if(maxpe < petimes[i])
44       maxpe = petimes[i];
45     if(minpe > petimes[i])
46       minpe = petimes[i];
47   }
48   delete[] petimes;
49   CkPrintf("\tLoad Imbalance=%lf seconds\n", maxpe-minpe);
50   int ncomm = 0;
51   for(i=0;i<numobjs;i++) {
52     for(j=0;j<numobjs;j++) {
53       if(map[i] != map[j])
54         ncomm += comm[i][j];
55     }
56   }
57   CkPrintf("\tCommunication (off proc msgs) = %d\n", ncomm/2);
58 }
59
60 extern "C" void METIS_PartGraphRecursive(int*, int*, int*, int*, int*,
61                                          int*, int*, int*, int*,
62                                          int*, int*);
63 extern "C" void METIS_PartGraphKway(int*, int*, int*, int*, int*,
64                                     int*, int*, int*, int*,
65                                     int*, int*);
66 extern "C" void METIS_PartGraphVKway(int*, int*, int*, int*, int*,
67                                      int*, int*, int*, int*,
68                                      int*, int*);
69
70 // the following are to compute a partitioning with a given partition weights
71 // "W" means giving weights
72 extern "C" void METIS_WPartGraphRecursive(int*, int*, int*, int*, int*,
73                                           int*, int*, int*, float*, int*,
74                                           int*, int*);
75 extern "C" void METIS_WPartGraphKway(int*, int*, int*, int*, int*,
76                                      int*, int*, int*, float*, int*,
77                                      int*, int*);
78
79 // the following are for multiple constraint partition "mC"
80 extern "C" void METIS_mCPartGraphRecursive(int*, int*, int*, int*, int*, int*,
81                                          int*, int*, int*, int*,
82                                          int*, int*);
83 extern "C" void METIS_mCPartGraphKway(int*, int*, int*, int*, int*, int*,
84                                     int*, int*, int*, int*, int*,
85                                     int*, int*);
86
87 void MetisLB::work(BaseLB::LDStats* stats, int count)
88 {
89   if (_lb_args.debug() >= 2) {
90     CkPrintf("[%d] In MetisLB Strategy...\n", CkMyPe());
91   }
92   int i, j, m;
93   int option = 0;
94
95   stats->makeCommHash();
96
97   removeNonMigratable(stats, count);
98
99   int numobjs = stats->n_objs;
100
101   // allocate space for the computing data
102   double *objtime = new double[numobjs];
103   int *objwt = new int[numobjs];
104   int *origmap = new int[numobjs];
105   LDObjHandle *handles = new LDObjHandle[numobjs];
106   for(i=0;i<numobjs;i++) {
107     objtime[i] = 0.0;
108     objwt[i] = 0;
109     origmap[i] = 0;
110   }
111
112   for (i=0; i<stats->n_objs; i++) {
113       LDObjData &odata = stats->objData[i];
114       if (!odata.migratable) 
115         CmiAbort("MetisLB doesnot dupport nonmigratable object.\n");
116       /*
117       origmap[odata[i].id.id[0]] = j;
118       cputime[odata[i].id.id[0]] = odata[i].cpuTime;
119       handles[odata[i].id.id[0]] = odata[i].handle;
120       */
121       int frompe = stats->from_proc[i];
122       origmap[i] = frompe;
123       objtime[i] = odata.wallTime*stats->procs[frompe].pe_speed;
124       handles[i] = odata.handle;
125   }
126
127   // to convert the weights on vertices to integers
128   double max_objtime = objtime[0];
129   for(i=0; i<numobjs; i++) {
130     if(max_objtime < objtime[i])
131       max_objtime = objtime[i];
132   }
133   double ratio = 1000.0/max_objtime;
134   for(i=0; i<numobjs; i++) {
135       objwt[i] = (int)(objtime[i]*ratio);
136   }
137   int **comm = new int*[numobjs];
138   for (i=0; i<numobjs; i++) {
139     comm[i] = new int[numobjs];
140     for (j=0; j<numobjs; j++)  {
141       comm[i][j] = 0;
142     }
143   }
144
145   const int csz = stats->n_comm;
146   for(i=0; i<csz; i++) {
147       LDCommData &cdata = stats->commData[i];
148       if(!cdata.from_proc() && cdata.receiver.get_type() == LD_OBJ_MSG)
149       {
150         int senderID = stats->getHash(cdata.sender);
151         int recverID = stats->getHash(cdata.receiver.get_destObj());
152         if (stats->complete_flag == 0 && recverID == -1) continue;
153         CmiAssert(senderID < numobjs && senderID >= 0);
154         CmiAssert(recverID < numobjs && recverID >= 0);
155         comm[senderID][recverID] += cdata.messages;
156         comm[recverID][senderID] += cdata.messages;
157       }
158       else if (cdata.receiver.get_type() == LD_OBJLIST_MSG) {
159         int nobjs;
160         LDObjKey *objs = cdata.receiver.get_destObjs(nobjs);
161         int senderID = stats->getHash(cdata.sender);
162         for (j=0; j<nobjs; j++) {
163            int recverID = stats->getHash(objs[j]);
164            if((senderID == -1)||(recverID == -1))
165               if (_lb_args.migObjOnly()) continue;
166               else CkAbort("Error in search\n");
167            comm[senderID][recverID] += cdata.messages;
168            comm[recverID][senderID] += cdata.messages;
169         }
170       }
171     }
172
173 // ignore messages sent from an object to itself
174   for (i=0; i<numobjs; i++)
175     comm[i][i] = 0;
176
177   // construct the graph in CSR format
178   int *xadj = new int[numobjs+1];
179   int numedges = 0;
180   for(i=0;i<numobjs;i++) {
181     for(j=0;j<numobjs;j++) {
182       if(comm[i][j] != 0)
183         numedges++;
184     }
185   }
186   int *adjncy = new int[numedges];
187   int *edgewt = new int[numedges];
188   xadj[0] = 0;
189   int count4all = 0;
190   for (i=0; i<numobjs; i++) {
191     for (j=0; j<numobjs; j++) { 
192       if (comm[i][j] != 0) { 
193         adjncy[count4all] = j;
194         edgewt[count4all++] = comm[i][j];
195       }
196     }
197     xadj[i+1] = count4all;
198   }
199
200   if (_lb_args.debug() >= 2) {
201   CkPrintf("Pre-LDB Statistics step %d\n", step());
202   printStats(count, numobjs, objtime, comm, origmap);
203   }
204
205   int wgtflag = 3; // Weights both on vertices and edges
206   int numflag = 0; // C Style numbering
207   int options[5];
208   options[0] = 0;
209   int edgecut;
210   int *newmap;
211   int sameMapFlag = 1;
212
213   if (count < 1) {
214     CkPrintf("error: Number of Pe less than 1!");
215   }
216   else if (count == 1) {
217     newmap = origmap;
218     sameMapFlag = 1;
219   }
220   else {
221     sameMapFlag = 0;
222     newmap = new int[numobjs];
223     //for(i=0;i<(numobjs+1);i++)
224       //xadj[i] = 0;
225     //delete[] edgewt;
226     //edgewt = 0;
227     //wgtflag = 2;
228     // CkPrintf("before calling Metis functions. option is %d.\n", option);
229     if (0 == option) {
230
231 /*  I intended to follow the instruction in the Metis 4.0 manual
232     which said that METIS_PartGraphKway is preferable to 
233     METIS_PartGraphRecursive, when nparts > 8.
234     However, it turned out that there is bug in METIS_PartGraphKway,
235     and the function seg faulted when nparts = 4 or 9.
236     So right now I just comment that function out and always use the other one.
237 */
238 /*
239       if (count > 8)
240         METIS_PartGraphKway(&numobjs, xadj, adjncy, objwt, edgewt, 
241                             &wgtflag, &numflag, &count, options, 
242                             &edgecut, newmap);
243       else
244         METIS_PartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt, 
245                                  &wgtflag, &numflag, &count, options, 
246                                  &edgecut, newmap);
247 */
248       if (_lb_args.debug() >= 1)
249         CkPrintf("[%d] calling METIS_PartGraphRecursive.\n", CkMyPe());
250       METIS_PartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt,
251                                  &wgtflag, &numflag, &count, options,
252                                  &edgecut, newmap);
253       if (_lb_args.debug() >= 1)
254         CkPrintf("[%d] after calling Metis functions.\n", CkMyPe());
255     }
256     else if (WEIGHTED == option) {
257       CkPrintf("unepected\n");
258       float maxtotal_walltime = stats->procs[0].total_walltime;
259       for (m=1; m<count; m++) {
260         if (maxtotal_walltime < stats->procs[m].total_walltime)
261           maxtotal_walltime = stats->procs[m].total_walltime;
262       }
263       float totaltimeAllPe = 0.0;
264       for (m=0; m<count; m++) {
265         totaltimeAllPe += stats->procs[m].pe_speed * 
266           (maxtotal_walltime-stats->procs[m].bg_walltime);
267       }
268       // set up the different weights
269       float *tpwgts = new float[count];
270       for (m=0; m<count; m++) {
271         tpwgts[m] = stats->procs[m].pe_speed * 
272           (maxtotal_walltime-stats->procs[m].bg_walltime) / totaltimeAllPe;
273       }
274       if (count > 8)
275         METIS_WPartGraphKway(&numobjs, xadj, adjncy, objwt, edgewt, 
276                              &wgtflag, &numflag, &count, tpwgts, options, 
277                              &edgecut, newmap);
278       else
279         METIS_WPartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt, 
280                                   &wgtflag, &numflag, &count, tpwgts, options, 
281                                   &edgecut, newmap);
282       delete[] tpwgts;
283     }
284     else if (MULTI_CONSTRAINT == option) {
285       CkPrintf("Metis load balance strategy: ");
286       CkPrintf("multiple constraints not implemented yet.\n");
287     }
288   }
289   if (_lb_args.debug() >= 2) {
290   CkPrintf("Post-LDB Statistics step %d\n", step());
291   printStats(count, numobjs, objtime, comm, newmap);
292   }
293
294   for(i=0;i<numobjs;i++)
295     delete[] comm[i];
296   delete[] comm;
297   delete[] objtime;
298   delete[] xadj;
299   delete[] adjncy;
300   if(objwt) delete[] objwt;
301   if(edgewt) delete[] edgewt;
302         
303   /*CkPrintf("obj-proc mapping\n");
304         for(i=0;i<numobjs;i++)
305                 CkPrintf(" %d,%d ",i,newmap[i]);
306   */
307   if(!sameMapFlag) {
308     for(i=0; i<numobjs; i++) {
309       if(origmap[i] != newmap[i]) {
310         CmiAssert(stats->from_proc[i] == origmap[i]);
311         stats->to_proc[i] =  newmap[i];
312         if (_lb_args.debug() >= 3)
313             CkPrintf("[%d] Obj %d migrating from %d to %d\n", CkMyPe(),i,stats->from_proc[i],stats->to_proc[i]);
314       }
315     }
316   }
317         
318         //CkPrintf("chking wts on each partition...\n");
319
320 /*
321         int avg=0;
322         int *chkwt = new int[count];
323         for(i=0;i<count;i++)
324                 chkwt[i]=0;
325         //totalwt=0;
326         for(i=0;i<numobjs;i++){
327                 chkwt[newmap[i]] += objwt[i];
328                 avg += objwt[i];
329                 
330         }
331         
332         
333         for(i=0;i<count;i++)
334                 CkPrintf("%d -- %d\n",i,chkwt[i]);
335 */
336   delete[] origmap;
337   if(newmap != origmap)
338     delete[] newmap;
339   if (_lb_args.debug() >= 1) {
340    CkPrintf("[%d] MetisLB done! \n", CkMyPe());
341   }
342 }
343
344 #include "MetisLB.def.h"
345
346 /*@}*/