MeshStreamer: function rename to improve readability and avoid confusion
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include "MeshStreamer.decl.h"
5
6 // allocate more total buffer space then the maximum buffering limit but flush upon
7 // reaching totalBufferCapacity_
8 #define BUCKET_SIZE_FACTOR 4
9
10 //#define DEBUG_STREAMER 1
11
12 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
13
14 /*
15 class LocalMessage : public CMessage_LocalMessage {
16 public:
17     int numDataItems; 
18     int dataItemSize; 
19     char *data;
20
21     LocalMessage(int dataItemSizeInBytes) {
22         numDataItems = 0; 
23         dataItemSize = dataItemSizeInBytes; 
24     }
25
26     int addDataItem(void *dataItem) {
27         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
28         return ++numDataItems; 
29     } 
30
31     void *getDataItem(int index) {
32         return (void *) (&data[index * dataItemSize]);  
33     }
34
35 };
36 */
37
38 template<class dtype>
39 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
40 public:
41     int numDataItems;
42     int *destinationPes;
43     dtype *data;
44
45     MeshStreamerMessage(): numDataItems(0) {}   
46
47     int addDataItem(const dtype &dataItem) {
48         data[numDataItems] = dataItem;
49         return ++numDataItems; 
50     }
51
52     void markDestination(const int index, const int destinationPe) {
53         destinationPes[index] = destinationPe;
54     }
55
56     dtype getDataItem(const int index) {
57         return data[index];
58     }
59 };
60
61 template <class dtype>
62 class MeshStreamerClient : public Group {
63  public:
64      MeshStreamerClient();
65      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
66
67 };
68
69 template <class dtype>
70 class MeshStreamer : public Group {
71
72 private:
73     int bucketSize_; 
74     int totalBufferCapacity_;
75     int numDataItemsBuffered_;
76
77     int numNodes_; 
78     int numRows_; 
79     int numColumns_; 
80     int numPlanes_; 
81     int planeSize_;
82
83     CProxy_MeshStreamerClient<dtype> clientProxy_;
84
85     int myNodeIndex_;
86     int myPlaneIndex_;
87     int myColumnIndex_; 
88     int myRowIndex_;
89
90     CkCallback   userCallback_;
91
92     MeshStreamerMessage<dtype> **personalizedBuffers_; 
93     MeshStreamerMessage<dtype> **columnBuffers_; 
94     MeshStreamerMessage<dtype> **planeBuffers_;
95
96     void determineLocation(const int destinationPe, int &row, int &column, 
97         int &plane, MeshStreamerMessageType &msgType);
98
99     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
100         const int bucketIndex, const int destinationPe, 
101         const int rowIndex, const int columnIndex, 
102         const int planeIndex,
103         const MeshStreamerMessageType msgType, const dtype &dataItem);
104
105     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
106                             const int numBuffers, const int myIndex, 
107                             const int dimensionFactor);
108 public:
109
110     MeshStreamer(int totalBufferCapacity, int numRows, 
111                  int numColumns, int numPlanes, 
112                  const CProxy_MeshStreamerClient<dtype> &clientProxy);
113     ~MeshStreamer();
114
115       // entry
116     void insertData(const dtype &dataItem, const int destinationPe); 
117     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
118     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
119
120     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
121     void flushDirect();
122
123       // non entry
124     void associateCallback(CkCallback &cb) { 
125               userCallback_ = cb;
126               CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::flushDirect(), thisProxy));
127          }
128 };
129
130 template <class dtype>
131 MeshStreamerClient<dtype>::MeshStreamerClient() {}
132
133 template <class dtype>
134 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
135   CkError("Default implementation of receiveCombinedData should never be called\n");
136   delete msg;
137 }
138
139 template <class dtype>
140 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
141                            int numColumns, int numPlanes, 
142                            const CProxy_MeshStreamerClient<dtype> &clientProxy) {
143   // limit total number of messages in system to totalBufferCapacity
144   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
145   //   advantage of nonuniform filling of buckets
146   // the buffers for your own column and plane are never used
147   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
148   totalBufferCapacity_ = totalBufferCapacity;
149   numDataItemsBuffered_ = 0; 
150   numRows_ = numRows; 
151   numColumns_ = numColumns;
152   numPlanes_ = numPlanes; 
153   numNodes_ = CkNumPes(); 
154   clientProxy_ = clientProxy; 
155
156   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
157   for (int i = 0; i < numRows; i++) {
158     personalizedBuffers_[i] = NULL; 
159   }
160
161   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
162   for (int i = 0; i < numColumns; i++) {
163     columnBuffers_[i] = NULL; 
164   }
165
166   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
167   for (int i = 0; i < numPlanes; i++) {
168     planeBuffers_[i] = NULL; 
169   }
170
171   // determine plane, column, and row location of this node
172   myNodeIndex_ = CkMyPe();
173   planeSize_ = numRows_ * numColumns_; 
174   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
175   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
176   myRowIndex_ = indexWithinPlane / numColumns_;
177   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
178
179 }
180
181 template <class dtype>
182 MeshStreamer<dtype>::~MeshStreamer() {
183
184   for (int i = 0; i < numRows_; i++)
185       delete personalizedBuffers_[i]; 
186
187   for (int i = 0; i < numColumns_; i++)
188       delete columnBuffers_[i]; 
189
190   for (int i = 0; i < numPlanes_; i++)
191       delete planeBuffers_[i]; 
192
193   delete[] personalizedBuffers_;
194   delete[] columnBuffers_;
195   delete[] planeBuffers_; 
196
197 }
198
199 template <class dtype>
200 void MeshStreamer<dtype>::determineLocation(const int destinationPe, int &rowIndex, 
201                                      int &columnIndex, int &planeIndex, 
202                                      MeshStreamerMessageType &msgType) {
203   
204   int nodeIndex, indexWithinPlane; 
205
206   nodeIndex = destinationPe;
207   planeIndex = nodeIndex / planeSize_; 
208   if (planeIndex != myPlaneIndex_) {
209     msgType = PlaneMessage;     
210   }
211   else {
212     indexWithinPlane = nodeIndex - planeIndex * planeSize_;
213     rowIndex = indexWithinPlane / numColumns_;
214     columnIndex = indexWithinPlane - rowIndex * numColumns_; 
215     if (columnIndex != myColumnIndex_) {
216       msgType = ColumnMessage; 
217     }
218     else {
219       msgType = PersonalizedMessage;
220     }
221   }
222
223 }
224
225 template <class dtype>
226 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
227                                 const int bucketIndex, const int destinationPe, 
228                                 const int rowIndex, const int columnIndex, 
229                                 const int planeIndex, 
230                                 const MeshStreamerMessageType msgType, 
231                                 const dtype &dataItem) {
232
233   // allocate new message if necessary
234   if (messageBuffers[bucketIndex] == NULL) {
235     if (msgType == PersonalizedMessage) {
236       messageBuffers[bucketIndex] = 
237         new (0, bucketSize_) MeshStreamerMessage<dtype>();
238     }
239     else {
240       messageBuffers[bucketIndex] = 
241         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
242     }
243 #ifdef DEBUG_STREAMER
244     CkAssert(messageBuffers[bucketIndex] != NULL);
245 #endif
246   }
247   
248   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
249   
250   int numBuffered = destinationBucket->addDataItem(dataItem); 
251   if (msgType != PersonalizedMessage) {
252     destinationBucket->markDestination(numBuffered-1, destinationPe);
253   }
254   numDataItemsBuffered_++;
255   // copy data into message and send if buffer is full
256   if (numBuffered == bucketSize_) {
257     int destinationIndex;
258     CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
259     switch (msgType) {
260
261     case PlaneMessage:
262       destinationIndex = 
263         myNodeIndex_ + (planeIndex - myPlaneIndex_) * planeSize_;      
264       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
265       break;
266     case ColumnMessage:
267       destinationIndex = myNodeIndex_ + (columnIndex - myColumnIndex_);
268       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
269       break;
270     case PersonalizedMessage:
271       destinationIndex = myNodeIndex_ + (rowIndex - myRowIndex_) * numColumns_;
272       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
273       //      thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
274       break;
275     default: 
276       CkError("Incorrect MeshStreamer message type\n");
277       break;
278     }
279     messageBuffers[bucketIndex] = NULL;
280     numDataItemsBuffered_ -= numBuffered; 
281   }
282
283   if (numDataItemsBuffered_ == totalBufferCapacity_) {
284
285     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
286     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
287     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
288
289   }
290
291 }
292
293 template <class dtype>
294 void MeshStreamer<dtype>::insertData(const dtype &dataItem, const int destinationPe) {
295
296   int planeIndex, columnIndex, rowIndex; // location of destination
297   int indexWithinPlane; 
298
299   MeshStreamerMessageType msgType; 
300
301   determineLocation(destinationPe, rowIndex, columnIndex, planeIndex, msgType);
302
303   // determine which array of buffers is appropriate for this message
304   MeshStreamerMessage<dtype> **messageBuffers;
305   int bucketIndex; 
306
307   switch (msgType) {
308   case PlaneMessage:
309     messageBuffers = planeBuffers_; 
310     bucketIndex = planeIndex; 
311     break;
312   case ColumnMessage:
313     messageBuffers = columnBuffers_; 
314     bucketIndex = columnIndex; 
315     break;
316   case PersonalizedMessage:
317     messageBuffers = personalizedBuffers_; 
318     bucketIndex = rowIndex; 
319     break;
320   default: 
321     CkError("Unrecognized MeshStreamer message type\n");
322     break;
323   }
324
325   storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
326                columnIndex, planeIndex, msgType, dataItem);
327 }
328
329 template <class dtype>
330 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
331
332   int rowIndex, columnIndex, planeIndex, destinationPe; 
333   MeshStreamerMessageType msgType;   
334
335   for (int i = 0; i < msg->numDataItems; i++) {
336     destinationPe = msg->destinationPes[i];
337     dtype dataItem = msg->getDataItem(i);
338     determineLocation(destinationPe, rowIndex, columnIndex, 
339                       planeIndex, msgType);
340 #ifdef DEBUG_STREAMER
341     CkAssert(planeIndex == myPlaneIndex_);
342
343     if (msgType == PersonalizedMessage) {
344       CkAssert(columnIndex == myColumnIndex_);
345     }
346 #endif    
347
348     MeshStreamerMessage<dtype> **messageBuffers;
349     int bucketIndex; 
350
351     switch (msgType) {
352     case ColumnMessage:
353       messageBuffers = columnBuffers_; 
354       bucketIndex = columnIndex; 
355       break;
356     case PersonalizedMessage:
357       messageBuffers = personalizedBuffers_; 
358       bucketIndex = rowIndex; 
359       break;
360     default: 
361       CkError("Incorrect MeshStreamer message type\n");
362       break;
363     }
364
365     storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
366                  columnIndex, planeIndex, msgType, dataItem);
367     
368   }
369
370   delete msg;
371
372 }
373
374 /*
375 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
376
377   // sort data items into messages for each core on this node
378
379   LocalMessage *localMsgs[numPesPerNode_];
380   int dataSize = bucketSize_ * dataItemSize_;
381
382   for (int i = 0; i < numPesPerNode_; i++) {
383     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
384   }
385
386   int destinationPe;
387   for (int i = 0; i < msg->numDataItems; i++) {
388
389     destinationPe = msg->destinationPes[i]; 
390     void *dataItem = msg->getDataItem(i);   
391     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
392
393   }
394
395   for (int i = 0; i < numPesPerNode_; i++) {
396     if (localMsgs[i]->numDataItems > 0) {
397       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
398     }
399     else {
400       delete localMsgs[i];
401     }
402   }
403
404   delete msg; 
405
406 }
407 */
408
409 template <class dtype>
410 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
411                                       const int numBuffers, const int myIndex, 
412                                       const int dimensionFactor) {
413
414   int flushIndex, maxSize, destinationIndex;
415   MeshStreamerMessage<dtype> *destinationBucket; 
416   maxSize = 0;
417   for (int i = 0; i < numBuffers; i++) {
418     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
419       maxSize = messageBuffers[i]->numDataItems;
420       flushIndex = i;
421     } 
422   }
423   if (maxSize > 0) {
424     destinationBucket = messageBuffers[flushIndex];
425     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
426
427     if (destinationBucket->numDataItems < bucketSize_) {
428       // not sending the full buffer, shrink the message size
429       envelope *env = UsrToEnv(destinationBucket);
430       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
431     }
432     numDataItemsBuffered_ -= destinationBucket->numDataItems;
433
434     if (messageBuffers == personalizedBuffers_) {
435       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
436     }
437     else {
438       CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
439       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
440     }
441     messageBuffers[flushIndex] = NULL;
442   }
443 }
444
445 template <class dtype>
446 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
447 {
448
449     for (int i = 0; i < numBuffers; i++) {
450        if(messageBuffers[i] == NULL)
451            continue;
452        //flush all messages in i bucket
453        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
454        if (messageBuffers == personalizedBuffers_) {
455          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
456          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
457        }
458        else {
459          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
460            MeshStreamerMessage<dtype> *directMsg = 
461              new (0, 1) MeshStreamerMessage<dtype>();
462 #ifdef DEBUG_STREAMER
463            CkAssert(directMsg != NULL);
464 #endif
465            int destinationPe = messageBuffers[i]->destinationPes[j]; 
466            dtype dataItem = messageBuffers[i]->getDataItem(j);   
467            directMsg->addDataItem(dataItem);
468            clientProxy_[destinationPe].receiveCombinedData(directMsg);
469          }
470          delete messageBuffers[i];
471        }
472        messageBuffers[i] = NULL;
473     }
474
475 }
476
477 template <class dtype>
478 void MeshStreamer<dtype>::flushDirect(){
479     flushBuckets(planeBuffers_, numPlanes_);
480     flushBuckets(columnBuffers_, numColumns_);
481     flushBuckets(personalizedBuffers_, numRows_);
482
483 #ifdef DEBUG_STREAMER
484     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
485     CkAssert(numDataItemsBuffered_ == 0); 
486 #endif
487
488     if (!userCallback_.isInvalid()) {
489         CkStartQD(userCallback_);
490         userCallback_ = CkCallback();      // nullify the current callback
491     }
492 }
493
494 #define CK_TEMPLATES_ONLY
495 #include "MeshStreamer.def.h"
496 #undef CK_TEMPLATES_ONLY
497
498 #endif