make MeshStreamer a ck-lib library
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include "MeshStreamer.decl.h"
5
6 // allocate more total buffer space then the maximum buffering limit but flush upon
7 // reaching totalBufferCapacity_
8 #define BUCKET_SIZE_FACTOR 4
9
10
11 //#define DEBUG_STREAMER 1
12
13 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
14
15 /*
16 class LocalMessage : public CMessage_LocalMessage {
17 public:
18     int numDataItems; 
19     int dataItemSize; 
20     char *data;
21
22     LocalMessage(int dataItemSizeInBytes) {
23         numDataItems = 0; 
24         dataItemSize = dataItemSizeInBytes; 
25     }
26
27     int addDataItem(void *dataItem) {
28         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
29         return ++numDataItems; 
30     } 
31
32     void *getDataItem(int index) {
33         return (void *) (&data[index * dataItemSize]);  
34     }
35
36 };
37 */
38
39 template<class dtype>
40 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
41 public:
42     int numDataItems;
43     int dataItemSize;
44     int *destinationPes;
45     dtype *data;
46
47     MeshStreamerMessage(): numDataItems(0) {}   
48
49     int addDataItem(dtype &dataItem) {
50         data[numDataItems] = dataItem;
51         return ++numDataItems; 
52     }
53
54     void markDestination(int index, int destinationPe) {
55         destinationPes[index] = destinationPe;
56     }
57
58     dtype getDataItem(int index) {
59         return data[index];
60     }
61 };
62
63 template <class dtype>
64 class MeshStreamerClient : public Group {
65  public:
66      MeshStreamerClient();
67      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
68
69 };
70
71 template <class dtype>
72 class MeshStreamer : public Group {
73
74 private:
75     int bucketSize_; 
76     int totalBufferCapacity_;
77     int numDataItemsBuffered_;
78
79     int numNodes_; 
80     int numRows_; 
81     int numColumns_; 
82     int numPlanes_; 
83     int planeSize_;
84
85     CProxy_MeshStreamerClient<dtype> clientProxy_;
86
87     int myNodeIndex_;
88     int myPlaneIndex_;
89     int myColumnIndex_; 
90     int myRowIndex_;
91
92     MeshStreamerMessage<dtype> **personalizedBuffers_; 
93     MeshStreamerMessage<dtype> **columnBuffers_; 
94     MeshStreamerMessage<dtype> **planeBuffers_;
95
96     void determineLocation(const int destinationPe, int &row, int &column, 
97         int &plane, MeshStreamerMessageType &msgType);
98
99     void storeMessage(MeshStreamerMessage<dtype> **messageBuffers, 
100         const int bucketIndex, const int destinationPe, 
101         const int rowIndex, const int columnIndex, 
102         const int planeIndex,
103         const MeshStreamerMessageType msgType, dtype &data);
104
105     void flushLargestBucket(MeshStreamerMessage<dtype> **messageBuffers,
106                             const int numBuffers, const int myIndex, 
107                             const int dimensionFactor);
108 public:
109
110     MeshStreamer(int totalBufferCapacity, int numRows, 
111                  int numColumns, int numPlanes, 
112                  const CProxy_MeshStreamerClient<dtype> &clientProxy);
113
114     ~MeshStreamer();
115
116     void insertData(dtype &, int); 
117     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
118     void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
119
120     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, int);
121     void flushDirect();
122 };
123
124 template <class dtype>
125 MeshStreamerClient<dtype>::MeshStreamerClient() {}
126
127 template <class dtype>
128 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
129   CkError("Default implementation of receiveCombinedData should never be called\n");
130   delete msg;
131 }
132
133 template <class dtype>
134 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
135                            int numColumns, int numPlanes, 
136                            const CProxy_MeshStreamerClient<dtype> &clientProxy) {
137   // limit total number of messages in system to totalBufferCapacity
138   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
139   //   advantage of nonuniform filling of buckets
140   // the buffers for your own column and plane are never used
141   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
142   totalBufferCapacity_ = totalBufferCapacity;
143   numDataItemsBuffered_ = 0; 
144   numRows_ = numRows; 
145   numColumns_ = numColumns;
146   numPlanes_ = numPlanes; 
147   numNodes_ = CkNumPes(); 
148   clientProxy_ = clientProxy; 
149
150   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
151   for (int i = 0; i < numRows; i++) {
152     personalizedBuffers_[i] = NULL; 
153   }
154
155   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
156   for (int i = 0; i < numColumns; i++) {
157     columnBuffers_[i] = NULL; 
158   }
159
160   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
161   for (int i = 0; i < numPlanes; i++) {
162     planeBuffers_[i] = NULL; 
163   }
164
165   // determine plane, column, and row location of this node
166   myNodeIndex_ = CkMyPe();
167   planeSize_ = numRows_ * numColumns_; 
168   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
169   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
170   myRowIndex_ = indexWithinPlane / numColumns_;
171   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
172
173 }
174
175 template <class dtype>
176 MeshStreamer<dtype>::~MeshStreamer() {
177
178   for (int i = 0; i < numRows_; i++)
179       delete personalizedBuffers_[i]; 
180
181   for (int i = 0; i < numColumns_; i++)
182       delete columnBuffers_[i]; 
183
184   for (int i = 0; i < numPlanes_; i++)
185       delete planeBuffers_[i]; 
186
187   delete[] personalizedBuffers_;
188   delete[] columnBuffers_;
189   delete[] planeBuffers_; 
190
191 }
192
193 template <class dtype>
194 void MeshStreamer<dtype>::determineLocation(const int destinationPe, int &rowIndex, 
195                                      int &columnIndex, int &planeIndex, 
196                                      MeshStreamerMessageType &msgType) {
197   
198   int nodeIndex, indexWithinPlane; 
199
200   nodeIndex = destinationPe;
201   planeIndex = nodeIndex / planeSize_; 
202   if (planeIndex != myPlaneIndex_) {
203     msgType = PlaneMessage;     
204   }
205   else {
206     indexWithinPlane = nodeIndex - planeIndex * planeSize_;
207     rowIndex = indexWithinPlane / numColumns_;
208     columnIndex = indexWithinPlane - rowIndex * numColumns_; 
209     if (columnIndex != myColumnIndex_) {
210       msgType = ColumnMessage; 
211     }
212     else {
213       msgType = PersonalizedMessage;
214     }
215   }
216
217 }
218
219 template <class dtype>
220 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> **messageBuffers, 
221                                 const int bucketIndex, const int destinationPe, 
222                                 const int rowIndex, const int columnIndex, 
223                                 const int planeIndex, 
224                                 const MeshStreamerMessageType msgType, 
225                                 dtype &data) {
226
227   // allocate new message if necessary
228   if (messageBuffers[bucketIndex] == NULL) {
229     int dataSize = bucketSize_;  
230     if (msgType == PersonalizedMessage) {
231       messageBuffers[bucketIndex] = 
232         new (0, dataSize) MeshStreamerMessage<dtype>;
233     }
234     else {
235       messageBuffers[bucketIndex] = 
236         new (bucketSize_, dataSize) MeshStreamerMessage<dtype>;
237     }
238 #ifdef DEBUG_STREAMER
239     CkAssert(messageBuffers[bucketIndex] != NULL);
240 #endif
241   }
242   
243   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
244   
245   int numBuffered = destinationBucket->addDataItem(data); 
246   if (msgType != PersonalizedMessage) {
247     destinationBucket->markDestination(numBuffered-1, destinationPe);
248   }
249   numDataItemsBuffered_++;
250   // copy data into message and send if buffer is full
251   if (numBuffered == bucketSize_) {
252     int destinationIndex;
253     CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
254     switch (msgType) {
255
256     case PlaneMessage:
257       destinationIndex = 
258         myNodeIndex_ + (planeIndex - myPlaneIndex_) * planeSize_;      
259       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
260       break;
261     case ColumnMessage:
262       destinationIndex = myNodeIndex_ + (columnIndex - myColumnIndex_);
263       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
264       break;
265     case PersonalizedMessage:
266       destinationIndex = myNodeIndex_ + (rowIndex - myRowIndex_) * numColumns_;
267       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
268       //      thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
269       break;
270     default: 
271       CkError("Incorrect MeshStreamer message type\n");
272       break;
273     }
274     messageBuffers[bucketIndex] = NULL;
275     numDataItemsBuffered_ -= numBuffered; 
276   }
277
278   if (numDataItemsBuffered_ == totalBufferCapacity_) {
279
280     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
281     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
282     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
283
284   }
285
286 }
287
288 template <class dtype>
289 void MeshStreamer<dtype>::insertData(dtype &data, int destinationPe) {
290
291   int planeIndex, columnIndex, rowIndex; // location of destination
292   int indexWithinPlane; 
293
294   MeshStreamerMessageType msgType; 
295
296   determineLocation(destinationPe, rowIndex, columnIndex, planeIndex, msgType);
297
298   // determine which array of buffers is appropriate for this message
299   MeshStreamerMessage<dtype> **messageBuffers;
300   int bucketIndex; 
301
302   switch (msgType) {
303   case PlaneMessage:
304     messageBuffers = planeBuffers_; 
305     bucketIndex = planeIndex; 
306     break;
307   case ColumnMessage:
308     messageBuffers = columnBuffers_; 
309     bucketIndex = columnIndex; 
310     break;
311   case PersonalizedMessage:
312     messageBuffers = personalizedBuffers_; 
313     bucketIndex = rowIndex; 
314     break;
315   default: 
316     CkError("Unrecognized MeshStreamer message type\n");
317     break;
318   }
319
320   storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
321                columnIndex, planeIndex, msgType, data);
322 }
323
324 template <class dtype>
325 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
326
327   int rowIndex, columnIndex, planeIndex, destinationPe; 
328   MeshStreamerMessageType msgType;   
329
330   for (int i = 0; i < msg->numDataItems; i++) {
331     destinationPe = msg->destinationPes[i];
332     dtype dataItem = msg->getDataItem(i);
333     determineLocation(destinationPe, rowIndex, columnIndex, 
334                       planeIndex, msgType);
335 #ifdef DEBUG_STREAMER
336     CkAssert(planeIndex == myPlaneIndex_);
337
338     if (msgType == PersonalizedMessage) {
339       CkAssert(columnIndex == myColumnIndex_);
340     }
341 #endif    
342
343     MeshStreamerMessage<dtype> **messageBuffers;
344     int bucketIndex; 
345
346     switch (msgType) {
347     case ColumnMessage:
348       messageBuffers = columnBuffers_; 
349       bucketIndex = columnIndex; 
350       break;
351     case PersonalizedMessage:
352       messageBuffers = personalizedBuffers_; 
353       bucketIndex = rowIndex; 
354       break;
355     default: 
356       CkError("Incorrect MeshStreamer message type\n");
357       break;
358     }
359
360     storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
361                  columnIndex, planeIndex, msgType, dataItem);
362     
363   }
364
365   delete msg;
366
367 }
368
369 /*
370 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
371
372   // sort data items into messages for each core on this node
373
374   LocalMessage *localMsgs[numPesPerNode_];
375   int dataSize = bucketSize_ * dataItemSize_;
376
377   for (int i = 0; i < numPesPerNode_; i++) {
378     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
379   }
380
381   int destinationPe;
382   for (int i = 0; i < msg->numDataItems; i++) {
383
384     destinationPe = msg->destinationPes[i]; 
385     void *dataItem = msg->getDataItem(i);   
386     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
387
388   }
389
390   for (int i = 0; i < numPesPerNode_; i++) {
391     if (localMsgs[i]->numDataItems > 0) {
392       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
393     }
394     else {
395       delete localMsgs[i];
396     }
397   }
398
399   delete msg; 
400
401 }
402 */
403
404 template <class dtype>
405 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> **messageBuffers,
406                                       const int numBuffers, const int myIndex, 
407                                       const int dimensionFactor) {
408
409   int flushIndex, maxSize, destinationIndex;
410   MeshStreamerMessage<dtype> *destinationBucket; 
411   maxSize = 0;
412   for (int i = 0; i < numBuffers; i++) {
413     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
414       maxSize = messageBuffers[i]->numDataItems;
415       flushIndex = i;
416     } 
417   }
418   if (maxSize > 0) {
419     destinationBucket = messageBuffers[flushIndex];
420     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
421     if (destinationBucket != NULL) {
422       numDataItemsBuffered_ -= destinationBucket->numDataItems;
423     }
424     if (messageBuffers == personalizedBuffers_) {
425       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
426     }
427     else {
428       CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
429       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
430     }
431     messageBuffers[flushIndex] = NULL;
432   }
433 }
434
435 template <class dtype>
436 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, int numBuffers)
437 {
438
439     for (int i = 0; i < numBuffers; i++) {
440        if(messageBuffers[i] == NULL)
441            continue;
442        //flush all messages in i bucket
443        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
444        if (messageBuffers == personalizedBuffers_) {
445          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
446          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
447        }
448        else {
449          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
450            MeshStreamerMessage<dtype> *directMsg = 
451              new (0, 1) MeshStreamerMessage<dtype>;
452 #ifdef DEBUG_STREAMER
453            CkAssert(directMsg != NULL);
454 #endif
455            int destinationPe = messageBuffers[i]->destinationPes[j]; 
456            dtype dataItem = messageBuffers[i]->getDataItem(j);   
457            directMsg->addDataItem(dataItem);
458            clientProxy_[destinationPe].receiveCombinedData(directMsg);
459          }
460          delete messageBuffers[i];
461        }
462        messageBuffers[i] = NULL;
463     }
464
465 #ifdef DEBUG_STREAMER
466     CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
467     CkAssert(numDataItemsBuffered_ == 0); 
468 #endif
469 }
470
471 template <class dtype>
472 void MeshStreamer<dtype>::flushDirect(){
473     flushBuckets(planeBuffers_, numPlanes_);
474     flushBuckets(columnBuffers_, numColumns_);
475     flushBuckets(personalizedBuffers_, numRows_);
476 }
477
478 #define CK_TEMPLATES_ONLY
479 #include "MeshStreamer.def.h"
480 #undef CK_TEMPLATES_ONLY
481
482 #endif