430da0a717b11cac877eeae042a49b2f8a619763
[charm.git] / src / ck-ldb / MetisLB.C
1 #include <charm++.h>
2
3 #if CMK_LBDB_ON
4
5 #if CMK_STL_USE_DOT_H
6 #include <deque.h>
7 #include <queue.h>
8 #else
9 #include <deque>
10 #include <queue>
11 #endif
12
13 #include "MetisLB.h"
14 #include "MetisLB.def.h"
15
16 #if CMK_STL_USE_DOT_H
17 template class deque<CentralLB::MigrateInfo>;
18 #else
19 template class std::deque<CentralLB::MigrateInfo>;
20 #endif
21
22 void CreateMetisLB()
23 {
24   // CkPrintf("[%d] creating MetisLB %d\n",CkMyPe(),loadbalancer);
25   loadbalancer = CProxy_MetisLB::ckNew();
26   // CkPrintf("[%d] created MetisLB %d\n",CkMyPe(),loadbalancer);
27 }
28
29 MetisLB::MetisLB()
30 {
31   // CkPrintf("[%d] MetisLB created\n",CkMyPe());
32 }
33
34 CmiBool MetisLB::QueryBalanceNow(int _step)
35 {
36   // CkPrintf("[%d] Balancing on step %d\n",CkMyPe(),_step);
37   return CmiTrue;
38 }
39
40 static void printStats(int count, int numobjs, double *cputimes, 
41                        int **comm, int *map)
42 {
43   int i, j;
44   double *petimes = new double[count];
45   for(i=0;i<count;i++) {
46     petimes[i] = 0.0;
47   }
48   for(i=0;i<numobjs;i++) {
49     petimes[map[i]] += cputimes[i];
50   }
51   double maxpe = petimes[0], minpe = petimes[0];
52   CkPrintf("\tPE\tTime\n");
53   for(i=0;i<count;i++) {
54     CkPrintf("\t%d\t%lf\n",i,petimes[i]);
55     if(maxpe < petimes[i])
56       maxpe = petimes[i];
57     if(minpe > petimes[i])
58       minpe = petimes[i];
59   }
60   delete[] petimes;
61   CkPrintf("\tLoad Imbalance=%lf seconds\n", maxpe-minpe);
62   int ncomm = 0;
63   for(i=0;i<numobjs;i++) {
64     for(j=0;j<numobjs;j++) {
65       if(map[i] != map[j])
66         ncomm += comm[i][j];
67     }
68   }
69   CkPrintf("\tCommunication (off proc msgs) = %d\n", ncomm/2);
70 }
71
72 extern "C" void METIS_PartGraphKway(int*, int*, int*, int*, int*,
73                                     int*, int*, int*, int*,
74                                     int*, int*);
75 extern "C" void METIS_PartGraphRecursive(int*, int*, int*, int*, int*,
76                                     int*, int*, int*, int*,
77                                     int*, int*);
78 extern "C" void METIS_PartGraphVKway(int*, int*, int*, int*, int*,
79                                     int*, int*, int*, int*,
80                                     int*, int*);
81
82 CLBMigrateMsg* MetisLB::Strategy(CentralLB::LDStats* stats, int count)
83 {
84   // CkPrintf("[%d] MetisLB strategy\n",CkMyPe());
85
86 #if CMK_STL_USE_DOT_H
87   queue<MigrateInfo> migrateInfo;
88 #else
89   std::queue<MigrateInfo> migrateInfo;
90 #endif
91
92   int i, j;
93   int numobjs = 0;
94   for (j=0; j < count; j++) {
95     numobjs += stats[j].n_objs;
96   }
97
98   // allocate space for the computing data
99   double *cputime = new double[numobjs];
100   int *objwt = new int[numobjs];
101   int *origmap = new int[numobjs];
102   LDObjHandle *handles = new LDObjHandle[numobjs];
103   for(i=0;i<numobjs;i++) {
104     cputime[i] = 0.0;
105     objwt[i] = 0;
106     origmap[i] = 0;
107   }
108
109   for (j=0; j<count; j++) {
110     for (i=0; i<stats[j].n_objs; i++) {
111       LDObjData *odata = stats[j].objData;
112       origmap[odata[i].id.id[0]] = j;
113       cputime[odata[i].id.id[0]] = odata[i].cpuTime;
114       handles[odata[i].id.id[0]] = odata[i].handle;
115     }
116   }
117   double max_cputime = cputime[0];
118   for(i=0; i<numobjs; i++) {
119     if(max_cputime < cputime[i])
120       max_cputime = cputime[i];
121   }
122   double ratio = 1000.0/max_cputime;
123   for(i=0; i<numobjs; i++) {
124     objwt[i] = (int)(cputime[i]*ratio);
125   }
126   int **comm = new int*[numobjs];
127   for (i=0; i<numobjs; i++) {
128     comm[i] = new int[numobjs];
129     for (j=0; j<numobjs; j++)  {
130       comm[i][j] = 0;
131     }
132   }
133
134   for(j=0; j<count; j++) {
135     LDCommData *cdata = stats[j].commData;
136     const int csz = stats[j].n_comm;
137     for(i=0; i<csz; i++) {
138       if(cdata[i].from_proc || cdata[i].to_proc)
139         continue;
140       int senderID = cdata[i].sender.id[0];
141       int recverID = cdata[i].receiver.id[0];
142       comm[senderID][recverID] += cdata[i].messages;
143       comm[recverID][senderID] += cdata[i].messages;
144     }
145   }
146   // ignore messages sent from an object to itself
147   for (i=0; i<numobjs; i++)
148     comm[i][i] = 0;
149
150   // construct the graph in CSR format
151   int *xadj = new int[numobjs+1];
152   int numedges = 0;
153   for(i=0;i<numobjs;i++) {
154     for(j=0;j<numobjs;j++) {
155       if(comm[i][j] != 0)
156         numedges++;
157     }
158   }
159   int *adjncy = new int[numedges];
160   int *edgewt = new int[numedges];
161   xadj[0] = 0;
162   int count4all = 0;
163   for (i=0; i<numobjs; i++) {
164     for (j=0; j<numobjs; j++) { 
165       if (comm[i][j] != 0) { 
166         adjncy[count4all] = j;
167         edgewt[count4all++] = comm[i][j];
168       }
169     }
170     xadj[i+1] = count4all;
171   }
172
173   CkPrintf("Pre-LDB Statistics step %d\n", step());
174   printStats(count, numobjs, cputime, comm, origmap);
175
176   int wgtflag = 3; // Weights both on vertices and edges
177   int numflag = 0; // C Style numbering
178   int options[5];
179   options[0] = 0;
180   int edgecut;
181   int *newmap;
182
183   if(count > 1) {
184     newmap = new int[numobjs];
185     for(i=0;i<(numobjs+1);i++)
186       xadj[i] = 0;
187     delete[] edgewt;
188     edgewt = 0;
189     wgtflag = 2;
190     METIS_PartGraphRecursive(&numobjs, xadj, adjncy, objwt, edgewt, 
191                          &wgtflag, &numflag, &count, options, 
192                          &edgecut, newmap);
193   } else {
194     newmap = origmap;
195   }
196   CkPrintf("Post-LDB Statistics step %d\n", step());
197   printStats(count, numobjs, cputime, comm, newmap);
198
199   for(i=0;i<numobjs;i++)
200     delete[] comm[i];
201   delete[] comm;
202   delete[] cputime;
203   delete[] xadj;
204   delete[] adjncy;
205   if(objwt) delete[] objwt;
206   if(edgewt) delete[] edgewt;
207
208   for(i=0; i<numobjs; i++) {
209     if(origmap[i] != newmap[i]) {
210       MigrateInfo migrateMe;
211       migrateMe.obj = handles[i];
212       migrateMe.from_pe = origmap[i];
213       migrateMe.to_pe = newmap[i];
214       migrateInfo.push(migrateMe);
215     }
216   }
217
218   delete[] origmap;
219   if(newmap != origmap)
220     delete[] newmap;
221
222   int migrate_count=migrateInfo.size();
223   CkPrintf("Migration Count = %d\n", migrate_count);
224   CLBMigrateMsg* msg = new(&migrate_count,1) CLBMigrateMsg;
225   msg->n_moves = migrate_count;
226   for(i=0; i < migrate_count; i++) {
227     msg->moves[i] = migrateInfo.front();
228     migrateInfo.pop();
229   }
230
231   return msg;
232 }
233
234 #endif