MeshStreamer: Added the option to let users specify when each group member
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include "MeshStreamer.decl.h"
5
6 // allocate more total buffer space than the maximum buffering limit but flush upon
7 // reaching totalBufferCapacity_
8 #define BUCKET_SIZE_FACTOR 4
9
10 //#define DEBUG_STREAMER 1
11
12 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
13
14 /*
15 class LocalMessage : public CMessage_LocalMessage {
16 public:
17     int numDataItems; 
18     int dataItemSize; 
19     char *data;
20
21     LocalMessage(int dataItemSizeInBytes) {
22         numDataItems = 0; 
23         dataItemSize = dataItemSizeInBytes; 
24     }
25
26     int addDataItem(void *dataItem) {
27         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
28         return ++numDataItems; 
29     } 
30
31     void *getDataItem(int index) {
32         return (void *) (&data[index * dataItemSize]);  
33     }
34
35 };
36 */
37
38 template<class dtype>
39 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
40 public:
41     int numDataItems;
42     int *destinationPes;
43     dtype *data;
44
45     MeshStreamerMessage(): numDataItems(0) {}   
46
47     int addDataItem(const dtype &dataItem) {
48         data[numDataItems] = dataItem;
49         return ++numDataItems; 
50     }
51
52     void markDestination(const int index, const int destinationPe) {
53         destinationPes[index] = destinationPe;
54     }
55
56     dtype getDataItem(const int index) {
57         return data[index];
58     }
59 };
60
61 template <class dtype>
62 class MeshStreamerClient : public Group {
63  public:
64      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
65      virtual void process(dtype data)=0; 
66 };
67
68 template <class dtype>
69 class MeshStreamer : public Group {
70
71 private:
72     int bucketSize_; 
73     int totalBufferCapacity_;
74     int numDataItemsBuffered_;
75
76     int numNodes_; 
77     int numRows_; 
78     int numColumns_; 
79     int numPlanes_; 
80     int planeSize_;
81
82     CProxy_MeshStreamerClient<dtype> clientProxy_;
83     MeshStreamerClient<dtype> *clientObj_;
84
85     int myNodeIndex_;
86     int myPlaneIndex_;
87     int myColumnIndex_; 
88     int myRowIndex_;
89
90     CkCallback   userCallback_;
91     int yieldFlag_;
92
93     int progressPeriodInMs_; 
94     bool isPeriodicFlushEnabled_; 
95     double timeOfLastSend_; 
96
97     MeshStreamerMessage<dtype> **personalizedBuffers_; 
98     MeshStreamerMessage<dtype> **columnBuffers_; 
99     MeshStreamerMessage<dtype> **planeBuffers_;
100
101     void determineLocation(const int destinationPe, int &row, int &column, 
102         int &plane, MeshStreamerMessageType &msgType);
103
104     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
105         const int bucketIndex, const int destinationPe, 
106         const int rowIndex, const int columnIndex, 
107         const int planeIndex,
108         const MeshStreamerMessageType msgType, const dtype &dataItem);
109
110     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
111                             const int numBuffers, const int myIndex, 
112                             const int dimensionFactor);
113
114 public:
115
116     MeshStreamer(int totalBufferCapacity, int numRows, 
117                  int numColumns, int numPlanes, 
118                  const CProxy_MeshStreamerClient<dtype> &clientProxy,
119                  int yieldFlag = 0, int progressPeriodInMs = -1);
120     ~MeshStreamer();
121
122       // entry
123     void insertData(const dtype &dataItem, const int destinationPe); 
124     void doneInserting();
125     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
126     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
127
128     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
129     void flushDirect();
130
131     bool isPeriodicFlushEnabled() {
132       return isPeriodicFlushEnabled_;
133     }
134       // non entry
135     void associateCallback(CkCallback &cb, bool automaticFinish = true) { 
136       userCallback_ = cb;
137       if (automaticFinish) {
138         CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), thisProxy));
139       }
140     }
141
142     void registerPeriodicProgressFunction();
143     void finish(CkReductionMsg *msg);
144 };
145
146 template <class dtype>
147 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
148   for (int i = 0; i < msg->numDataItems; i++) {
149      dtype data = ((dtype*)(msg->data))[i];
150      process(data);
151   }
152   delete msg;
153 }
154
155 template <class dtype>
156 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
157                            int numColumns, int numPlanes, 
158                            const CProxy_MeshStreamerClient<dtype> &clientProxy,
159                            int yieldFlag, int progressPeriodInMs): yieldFlag_(yieldFlag) {
160   // limit total number of messages in system to totalBufferCapacity
161   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
162   //   advantage of nonuniform filling of buckets
163   // the buffers for your own column and plane are never used
164   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
165   totalBufferCapacity_ = totalBufferCapacity;
166   numDataItemsBuffered_ = 0; 
167   numRows_ = numRows; 
168   numColumns_ = numColumns;
169   numPlanes_ = numPlanes; 
170   numNodes_ = CkNumPes(); 
171   clientProxy_ = clientProxy; 
172   clientObj_ = ((MeshStreamerClient<dtype> *)CkLocalBranch(clientProxy_));
173   progressPeriodInMs_ = progressPeriodInMs; 
174
175   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
176   for (int i = 0; i < numRows; i++) {
177     personalizedBuffers_[i] = NULL; 
178   }
179
180   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
181   for (int i = 0; i < numColumns; i++) {
182     columnBuffers_[i] = NULL; 
183   }
184
185   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
186   for (int i = 0; i < numPlanes; i++) {
187     planeBuffers_[i] = NULL; 
188   }
189
190   // determine plane, column, and row location of this node
191   myNodeIndex_ = CkMyPe();
192   planeSize_ = numRows_ * numColumns_; 
193   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
194   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
195   myRowIndex_ = indexWithinPlane / numColumns_;
196   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
197
198   if (progressPeriodInMs_ > 0) {
199     isPeriodicFlushEnabled_ = true; 
200     registerPeriodicProgressFunction();
201   }
202   else {
203     isPeriodicFlushEnabled_ = false; 
204   }
205
206 }
207
208 template <class dtype>
209 MeshStreamer<dtype>::~MeshStreamer() {
210
211   for (int i = 0; i < numRows_; i++)
212       delete personalizedBuffers_[i]; 
213
214   for (int i = 0; i < numColumns_; i++)
215       delete columnBuffers_[i]; 
216
217   for (int i = 0; i < numPlanes_; i++)
218       delete planeBuffers_[i]; 
219
220   delete[] personalizedBuffers_;
221   delete[] columnBuffers_;
222   delete[] planeBuffers_; 
223
224 }
225
226 template <class dtype>
227 void MeshStreamer<dtype>::determineLocation(const int destinationPe, int &rowIndex, 
228                                      int &columnIndex, int &planeIndex, 
229                                      MeshStreamerMessageType &msgType) {
230   
231   int nodeIndex, indexWithinPlane; 
232
233   nodeIndex = destinationPe;
234   planeIndex = nodeIndex / planeSize_; 
235   if (planeIndex != myPlaneIndex_) {
236     msgType = PlaneMessage;     
237   }
238   else {
239     indexWithinPlane = nodeIndex - planeIndex * planeSize_;
240     rowIndex = indexWithinPlane / numColumns_;
241     columnIndex = indexWithinPlane - rowIndex * numColumns_; 
242     if (columnIndex != myColumnIndex_) {
243       msgType = ColumnMessage; 
244     }
245     else {
246       msgType = PersonalizedMessage;
247     }
248   }
249
250 }
251
252 template <class dtype>
253 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
254                                 const int bucketIndex, const int destinationPe, 
255                                 const int rowIndex, const int columnIndex, 
256                                 const int planeIndex, 
257                                 const MeshStreamerMessageType msgType, 
258                                 const dtype &dataItem) {
259
260   // allocate new message if necessary
261   if (messageBuffers[bucketIndex] == NULL) {
262     if (msgType == PersonalizedMessage) {
263       messageBuffers[bucketIndex] = 
264         new (0, bucketSize_) MeshStreamerMessage<dtype>();
265     }
266     else {
267       messageBuffers[bucketIndex] = 
268         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
269     }
270 #ifdef DEBUG_STREAMER
271     CkAssert(messageBuffers[bucketIndex] != NULL);
272 #endif
273   }
274   
275   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
276   
277   int numBuffered = destinationBucket->addDataItem(dataItem); 
278   if (msgType != PersonalizedMessage) {
279     destinationBucket->markDestination(numBuffered-1, destinationPe);
280   }
281   numDataItemsBuffered_++;
282   // copy data into message and send if buffer is full
283   if (numBuffered == bucketSize_) {
284     int destinationIndex;
285     CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
286     switch (msgType) {
287
288     case PlaneMessage:
289       destinationIndex = 
290         myNodeIndex_ + (planeIndex - myPlaneIndex_) * planeSize_;      
291       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
292       break;
293     case ColumnMessage:
294       destinationIndex = myNodeIndex_ + (columnIndex - myColumnIndex_);
295       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
296       break;
297     case PersonalizedMessage:
298       destinationIndex = myNodeIndex_ + (rowIndex - myRowIndex_) * numColumns_;
299       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
300       //      thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
301       break;
302     default: 
303       CkError("Incorrect MeshStreamer message type\n");
304       break;
305     }
306     messageBuffers[bucketIndex] = NULL;
307     numDataItemsBuffered_ -= numBuffered; 
308
309     if (isPeriodicFlushEnabled_) {
310       timeOfLastSend_ = CkWallTimer();
311     }
312
313   }
314
315   if (numDataItemsBuffered_ == totalBufferCapacity_) {
316
317     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
318     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
319     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
320
321     if (isPeriodicFlushEnabled_) {
322       timeOfLastSend_ = CkWallTimer();
323     }
324
325   }
326
327 }
328
329 template <class dtype>
330 void MeshStreamer<dtype>::insertData(const dtype &dataItem, const int destinationPe) {
331   static int count = 0;
332
333   if (destinationPe == CkMyPe()) {
334     clientObj_->process(dataItem);
335     return;
336   }
337
338   int planeIndex, columnIndex, rowIndex; // location of destination
339   int indexWithinPlane; 
340   MeshStreamerMessageType msgType; 
341
342   determineLocation(destinationPe, rowIndex, columnIndex, planeIndex, msgType);
343
344   // determine which array of buffers is appropriate for this message
345   MeshStreamerMessage<dtype> **messageBuffers;
346   int bucketIndex; 
347
348   switch (msgType) {
349   case PlaneMessage:
350     messageBuffers = planeBuffers_; 
351     bucketIndex = planeIndex; 
352     break;
353   case ColumnMessage:
354     messageBuffers = columnBuffers_; 
355     bucketIndex = columnIndex; 
356     break;
357   case PersonalizedMessage:
358     messageBuffers = personalizedBuffers_; 
359     bucketIndex = rowIndex; 
360     break;
361   default: 
362     CkError("Unrecognized MeshStreamer message type\n");
363     break;
364   }
365
366   storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
367                columnIndex, planeIndex, msgType, dataItem);
368
369     // release control to scheduler if requested by the user, 
370     //   assume caller is threaded entry
371   if (yieldFlag_ && ++count % 1024 == 0) CthYield();
372 }
373
374 template <class dtype>
375 void MeshStreamer<dtype>::doneInserting() {
376   contribute(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL)), thisProxy);
377 }
378
379 template <class dtype>
380 void MeshStreamer<dtype>::finish(CkReductionMsg *msg) {
381
382   isPeriodicFlushEnabled_ = false; 
383   flushDirect();
384
385   if (!userCallback_.isInvalid()) {
386     CkStartQD(userCallback_);
387     userCallback_ = CkCallback();      // nullify the current callback
388   }
389
390   delete msg; 
391 }
392
393
394 template <class dtype>
395 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
396
397   int rowIndex, columnIndex, planeIndex, destinationPe; 
398   MeshStreamerMessageType msgType;   
399
400   for (int i = 0; i < msg->numDataItems; i++) {
401     destinationPe = msg->destinationPes[i];
402     dtype dataItem = msg->getDataItem(i);
403     determineLocation(destinationPe, rowIndex, columnIndex, 
404                       planeIndex, msgType);
405 #ifdef DEBUG_STREAMER
406     CkAssert(planeIndex == myPlaneIndex_);
407
408     if (msgType == PersonalizedMessage) {
409       CkAssert(columnIndex == myColumnIndex_);
410     }
411 #endif    
412
413     MeshStreamerMessage<dtype> **messageBuffers;
414     int bucketIndex; 
415
416     switch (msgType) {
417     case ColumnMessage:
418       messageBuffers = columnBuffers_; 
419       bucketIndex = columnIndex; 
420       break;
421     case PersonalizedMessage:
422       messageBuffers = personalizedBuffers_; 
423       bucketIndex = rowIndex; 
424       break;
425     default: 
426       CkError("Incorrect MeshStreamer message type\n");
427       break;
428     }
429
430     storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
431                  columnIndex, planeIndex, msgType, dataItem);
432     
433   }
434
435   delete msg;
436
437 }
438
439 /*
440 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
441
442   // sort data items into messages for each core on this node
443
444   LocalMessage *localMsgs[numPesPerNode_];
445   int dataSize = bucketSize_ * dataItemSize_;
446
447   for (int i = 0; i < numPesPerNode_; i++) {
448     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
449   }
450
451   int destinationPe;
452   for (int i = 0; i < msg->numDataItems; i++) {
453
454     destinationPe = msg->destinationPes[i]; 
455     void *dataItem = msg->getDataItem(i);   
456     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
457
458   }
459
460   for (int i = 0; i < numPesPerNode_; i++) {
461     if (localMsgs[i]->numDataItems > 0) {
462       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
463     }
464     else {
465       delete localMsgs[i];
466     }
467   }
468
469   delete msg; 
470
471 }
472 */
473
474 template <class dtype>
475 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
476                                       const int numBuffers, const int myIndex, 
477                                       const int dimensionFactor) {
478
479   int flushIndex, maxSize, destinationIndex;
480   MeshStreamerMessage<dtype> *destinationBucket; 
481   maxSize = 0;
482   for (int i = 0; i < numBuffers; i++) {
483     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
484       maxSize = messageBuffers[i]->numDataItems;
485       flushIndex = i;
486     } 
487   }
488   if (maxSize > 0) {
489     destinationBucket = messageBuffers[flushIndex];
490     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
491
492     if (destinationBucket->numDataItems < bucketSize_) {
493       // not sending the full buffer, shrink the message size
494       envelope *env = UsrToEnv(destinationBucket);
495       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
496     }
497     numDataItemsBuffered_ -= destinationBucket->numDataItems;
498
499     if (messageBuffers == personalizedBuffers_) {
500       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
501     }
502     else {
503       CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
504       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
505     }
506     messageBuffers[flushIndex] = NULL;
507   }
508 }
509
510 template <class dtype>
511 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
512 {
513
514     for (int i = 0; i < numBuffers; i++) {
515        if(messageBuffers[i] == NULL)
516            continue;
517        //flush all messages in i bucket
518        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
519        if (messageBuffers == personalizedBuffers_) {
520          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
521          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
522        }
523        else {
524          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
525            MeshStreamerMessage<dtype> *directMsg = 
526              new (0, 1) MeshStreamerMessage<dtype>();
527 #ifdef DEBUG_STREAMER
528            CkAssert(directMsg != NULL);
529 #endif
530            int destinationPe = messageBuffers[i]->destinationPes[j]; 
531            dtype dataItem = messageBuffers[i]->getDataItem(j);   
532            directMsg->addDataItem(dataItem);
533            clientProxy_[destinationPe].receiveCombinedData(directMsg);
534          }
535          delete messageBuffers[i];
536        }
537        messageBuffers[i] = NULL;
538     }
539
540 }
541
542 template <class dtype>
543 void MeshStreamer<dtype>::flushDirect(){
544
545     if (!isPeriodicFlushEnabled_ || 
546         CkWallTimer() - timeOfLastSend_ >= progressPeriodInMs_) {
547       flushBuckets(planeBuffers_, numPlanes_);
548       flushBuckets(columnBuffers_, numColumns_);
549       flushBuckets(personalizedBuffers_, numRows_);
550     }
551
552     if (isPeriodicFlushEnabled_) {
553       timeOfLastSend_ = CkWallTimer();
554     }
555
556 #ifdef DEBUG_STREAMER
557     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
558     CkAssert(numDataItemsBuffered_ == 0); 
559 #endif
560
561 }
562
563 template <class dtype>
564 void periodicProgressFunction(void *MeshStreamerObj, double time) {
565
566   MeshStreamer<dtype> *properObj = 
567     static_cast<MeshStreamer<dtype>*>(MeshStreamerObj); 
568
569   if (properObj->isPeriodicFlushEnabled()) {
570     properObj->flushDirect();
571     properObj->registerPeriodicProgressFunction();
572   }
573 }
574
575 template <class dtype>
576 void MeshStreamer<dtype>::registerPeriodicProgressFunction() {
577   CcdCallFnAfter(periodicProgressFunction<dtype>, (void *) this, progressPeriodInMs_); 
578 }
579
580
581 #define CK_TEMPLATES_ONLY
582 #include "MeshStreamer.def.h"
583 #undef CK_TEMPLATES_ONLY
584
585 #endif