MeshStreamer: Inherit from CBase classes
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include "MeshStreamer.decl.h"
5 // allocate more total buffer space than the maximum buffering limit but flush upon
6 // reaching totalBufferCapacity_
7 #define BUCKET_SIZE_FACTOR 4
8
9 //#define DEBUG_STREAMER
10
11 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
12
13 class MeshLocation {
14  public:
15   int rowIndex;
16   int columnIndex;
17   int planeIndex; 
18   MeshStreamerMessageType msgType;
19 };
20
21 //#define HASH_LOCATIONS
22
23 #ifdef HASH_LOCATIONS
24 #include <map>
25 #endif
26
27 /*
28 class LocalMessage : public CMessage_LocalMessage {
29 public:
30     int numDataItems; 
31     int dataItemSize; 
32     char *data;
33
34     LocalMessage(int dataItemSizeInBytes) {
35         numDataItems = 0; 
36         dataItemSize = dataItemSizeInBytes; 
37     }
38
39     int addDataItem(void *dataItem) {
40         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
41         return ++numDataItems; 
42     } 
43
44     void *getDataItem(int index) {
45         return (void *) (&data[index * dataItemSize]);  
46     }
47
48 };
49 */
50
51 template<class dtype>
52 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
53 public:
54     int numDataItems;
55     int *destinationPes;
56     dtype *data;
57
58     MeshStreamerMessage(): numDataItems(0) {}   
59
60     int addDataItem(const dtype &dataItem) {
61         data[numDataItems] = dataItem;
62         return ++numDataItems; 
63     }
64
65     void markDestination(const int index, const int destinationPe) {
66         destinationPes[index] = destinationPe;
67     }
68
69     dtype &getDataItem(const int index) {
70         return data[index];
71     }
72 };
73
74 template <class dtype>
75 class MeshStreamerClient : public CBase_MeshStreamerClient<dtype> {
76  public:
77      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
78      virtual void process(dtype &data)=0; 
79 };
80
81 template <class dtype>
82 class MeshStreamer : public CBase_MeshStreamer<dtype> {
83
84 private:
85     int bucketSize_; 
86     int totalBufferCapacity_;
87     int numDataItemsBuffered_;
88
89     int numNodes_; 
90     int numRows_; 
91     int numColumns_; 
92     int numPlanes_; 
93     int planeSize_;
94
95     CProxy_MeshStreamerClient<dtype> clientProxy_;
96     MeshStreamerClient<dtype> *clientObj_;
97
98     int myNodeIndex_;
99     int myPlaneIndex_;
100     int myColumnIndex_; 
101     int myRowIndex_;
102
103     CkCallback   userCallback_;
104     int yieldFlag_;
105
106     double progressPeriodInMs_; 
107     bool isPeriodicFlushEnabled_; 
108     double timeOfLastSend_; 
109
110     MeshStreamerMessage<dtype> **personalizedBuffers_; 
111     MeshStreamerMessage<dtype> **columnBuffers_; 
112     MeshStreamerMessage<dtype> **planeBuffers_;
113
114 #ifdef HASH_LOCATIONS
115     std::map<int, MeshLocation> hashedLocations; 
116 #endif
117
118     void determineLocation(const int destinationPe, 
119                            MeshLocation &destinationCoordinates);
120
121     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
122                       const int bucketIndex, const int destinationPe, 
123                       const MeshLocation &destinationCoordinates, const dtype &dataItem);
124
125     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
126                             const int numBuffers, const int myIndex, 
127                             const int dimensionFactor);
128
129 public:
130
131     MeshStreamer(int totalBufferCapacity, int numRows, 
132                  int numColumns, int numPlanes, 
133                  const CProxy_MeshStreamerClient<dtype> &clientProxy,
134                  int yieldFlag = 0, double progressPeriodInMs = -1.0);
135     ~MeshStreamer();
136
137       // entry
138     void insertData(dtype &dataItem, const int destinationPe); 
139     void doneInserting();
140     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
141     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
142
143     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
144     void flushDirect();
145
146     bool isPeriodicFlushEnabled() {
147       return isPeriodicFlushEnabled_;
148     }
149       // non entry
150     void associateCallback(CkCallback &cb, bool automaticFinish = true) { 
151       userCallback_ = cb;
152       if (automaticFinish) {
153         CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy));
154       }
155     }
156
157     void registerPeriodicProgressFunction();
158     void finish(CkReductionMsg *msg);
159
160     /*
161      * Flushing begins on a PE only after enablePeriodicFlushing has been invoked.
162      */
163     void enablePeriodicFlushing(){
164       isPeriodicFlushEnabled_ = true; 
165       registerPeriodicProgressFunction();
166     }
167 };
168
169 template <class dtype>
170 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
171   for (int i = 0; i < msg->numDataItems; i++) {
172      dtype data = ((dtype*)(msg->data))[i];
173      process(data);
174   }
175   delete msg;
176 }
177
178 template <class dtype>
179 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
180                            int numColumns, int numPlanes, 
181                            const CProxy_MeshStreamerClient<dtype> &clientProxy,
182                            int yieldFlag, double progressPeriodInMs): yieldFlag_(yieldFlag) {
183   // limit total number of messages in system to totalBufferCapacity
184   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
185   //   advantage of nonuniform filling of buckets
186   // the buffers for your own column and plane are never used
187   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
188   totalBufferCapacity_ = totalBufferCapacity;
189   numDataItemsBuffered_ = 0; 
190   numRows_ = numRows; 
191   numColumns_ = numColumns;
192   numPlanes_ = numPlanes; 
193   numNodes_ = CkNumPes(); 
194   clientProxy_ = clientProxy; 
195   clientObj_ = ((MeshStreamerClient<dtype> *)CkLocalBranch(clientProxy_));
196   progressPeriodInMs_ = progressPeriodInMs; 
197
198   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
199   for (int i = 0; i < numRows; i++) {
200     personalizedBuffers_[i] = NULL; 
201   }
202
203   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
204   for (int i = 0; i < numColumns; i++) {
205     columnBuffers_[i] = NULL; 
206   }
207
208   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
209   for (int i = 0; i < numPlanes; i++) {
210     planeBuffers_[i] = NULL; 
211   }
212
213   // determine plane, column, and row location of this node
214   myNodeIndex_ = CkMyPe();
215   planeSize_ = numRows_ * numColumns_; 
216   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
217   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
218   myRowIndex_ = indexWithinPlane / numColumns_;
219   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
220
221   isPeriodicFlushEnabled_ = false; 
222 }
223
224 template <class dtype>
225 MeshStreamer<dtype>::~MeshStreamer() {
226
227   for (int i = 0; i < numRows_; i++)
228       delete personalizedBuffers_[i]; 
229
230   for (int i = 0; i < numColumns_; i++)
231       delete columnBuffers_[i]; 
232
233   for (int i = 0; i < numPlanes_; i++)
234       delete planeBuffers_[i]; 
235
236   delete[] personalizedBuffers_;
237   delete[] columnBuffers_;
238   delete[] planeBuffers_; 
239
240 }
241
242 template <class dtype>
243 void MeshStreamer<dtype>::determineLocation(const int destinationPe, 
244                                             MeshLocation &destinationCoordinates) { 
245
246   int nodeIndex, indexWithinPlane; 
247
248 #ifdef HASH_LOCATIONS
249   std::map<int, MeshLocation>::iterator it;
250   if ((it = hashedLocations.find(destinationPe)) != hashedLocations.end()) {
251     destinationCoordinates = it->second;
252     return;
253   }
254 #endif
255
256   nodeIndex = destinationPe;
257   destinationCoordinates.planeIndex = nodeIndex / planeSize_; 
258   if (destinationCoordinates.planeIndex != myPlaneIndex_) {
259     destinationCoordinates.msgType = PlaneMessage;     
260   }
261   else {
262     indexWithinPlane = 
263       nodeIndex - destinationCoordinates.planeIndex * planeSize_;
264     destinationCoordinates.rowIndex = indexWithinPlane / numColumns_;
265     destinationCoordinates.columnIndex = 
266       indexWithinPlane - destinationCoordinates.rowIndex * numColumns_; 
267     if (destinationCoordinates.columnIndex != myColumnIndex_) {
268       destinationCoordinates.msgType = ColumnMessage; 
269     }
270     else {
271       destinationCoordinates.msgType = PersonalizedMessage;
272     }
273   }
274
275 #ifdef HASH_LOCATIONS
276   hashedLocations[destinationPe] = destinationCoordinates;
277 #endif
278
279 }
280
281 template <class dtype>
282 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
283                                        const int bucketIndex, const int destinationPe, 
284                                        const MeshLocation& destinationCoordinates,
285                                        const dtype &dataItem) {
286
287   // allocate new message if necessary
288   if (messageBuffers[bucketIndex] == NULL) {
289     if (destinationCoordinates.msgType == PersonalizedMessage) {
290       messageBuffers[bucketIndex] = 
291         new (0, bucketSize_) MeshStreamerMessage<dtype>();
292     }
293     else {
294       messageBuffers[bucketIndex] = 
295         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
296     }
297 #ifdef DEBUG_STREAMER
298     CkAssert(messageBuffers[bucketIndex] != NULL);
299 #endif
300   }
301   
302   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
303   
304   int numBuffered = destinationBucket->addDataItem(dataItem); 
305   if (destinationCoordinates.msgType != PersonalizedMessage) {
306     destinationBucket->markDestination(numBuffered-1, destinationPe);
307   }
308   numDataItemsBuffered_++;
309   // copy data into message and send if buffer is full
310   if (numBuffered == bucketSize_) {
311     int destinationIndex;
312     switch (destinationCoordinates.msgType) {
313
314     case PlaneMessage:
315       destinationIndex = myNodeIndex_ + 
316         (destinationCoordinates.planeIndex - myPlaneIndex_) * planeSize_;  
317       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
318       break;
319     case ColumnMessage:
320       destinationIndex = myNodeIndex_ + 
321         (destinationCoordinates.columnIndex - myColumnIndex_);
322       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
323       break;
324     case PersonalizedMessage:
325       destinationIndex = myNodeIndex_ + 
326         (destinationCoordinates.rowIndex - myRowIndex_) * numColumns_;
327       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
328       //      this->thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
329       break;
330     default: 
331       CkError("Incorrect MeshStreamer message type\n");
332       break;
333     }
334     messageBuffers[bucketIndex] = NULL;
335     numDataItemsBuffered_ -= numBuffered; 
336
337     if (isPeriodicFlushEnabled_) {
338       timeOfLastSend_ = CkWallTimer();
339     }
340
341   }
342
343   if (numDataItemsBuffered_ == totalBufferCapacity_) {
344
345     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
346     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
347     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
348
349     if (isPeriodicFlushEnabled_) {
350       timeOfLastSend_ = CkWallTimer();
351     }
352
353   }
354
355 }
356
357 template <class dtype>
358 void MeshStreamer<dtype>::insertData(dtype &dataItem, const int destinationPe) {
359   static int count = 0;
360
361   if (destinationPe == CkMyPe()) {
362     clientObj_->process(dataItem);
363     return;
364   }
365
366   int indexWithinPlane; 
367   MeshLocation destinationCoordinates;
368
369   determineLocation(destinationPe, destinationCoordinates);
370
371   // determine which array of buffers is appropriate for this message
372   MeshStreamerMessage<dtype> **messageBuffers;
373   int bucketIndex; 
374
375   switch (destinationCoordinates.msgType) {
376   case PlaneMessage:
377     messageBuffers = planeBuffers_; 
378     bucketIndex = destinationCoordinates.planeIndex; 
379     break;
380   case ColumnMessage:
381     messageBuffers = columnBuffers_; 
382     bucketIndex = destinationCoordinates.columnIndex; 
383     break;
384   case PersonalizedMessage:
385     messageBuffers = personalizedBuffers_; 
386     bucketIndex = destinationCoordinates.rowIndex; 
387     break;
388   default: 
389     CkError("Unrecognized MeshStreamer message type\n");
390     break;
391   }
392
393   storeMessage(messageBuffers, bucketIndex, destinationPe, destinationCoordinates, 
394                dataItem);
395
396     // release control to scheduler if requested by the user, 
397     //   assume caller is threaded entry
398   if (yieldFlag_ && ++count % 1024 == 0) CthYield();
399 }
400
401 template <class dtype>
402 void MeshStreamer<dtype>::doneInserting() {
403   this->contribute(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy));
404 }
405
406 template <class dtype>
407 void MeshStreamer<dtype>::finish(CkReductionMsg *msg) {
408
409   isPeriodicFlushEnabled_ = false; 
410   flushDirect();
411
412   if (!userCallback_.isInvalid()) {
413     CkStartQD(userCallback_);
414     userCallback_ = CkCallback();      // nullify the current callback
415   }
416
417   //  delete msg; 
418 }
419
420
421 template <class dtype>
422 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
423
424   int destinationPe; 
425   MeshStreamerMessageType msgType;   
426   MeshLocation destinationCoordinates;
427
428   for (int i = 0; i < msg->numDataItems; i++) {
429     destinationPe = msg->destinationPes[i];
430     dtype &dataItem = msg->getDataItem(i);
431     determineLocation(destinationPe, destinationCoordinates);
432 #ifdef DEBUG_STREAMER
433     CkAssert(destinationCoordinates.planeIndex == myPlaneIndex_);
434
435     if (destinationCoordinates.msgType == PersonalizedMessage) {
436       CkAssert(destinationCoordinates.columnIndex == myColumnIndex_);
437     }
438 #endif    
439
440     MeshStreamerMessage<dtype> **messageBuffers;
441     int bucketIndex; 
442
443     switch (destinationCoordinates.msgType) {
444     case ColumnMessage:
445       messageBuffers = columnBuffers_; 
446       bucketIndex = destinationCoordinates.columnIndex; 
447       break;
448     case PersonalizedMessage:
449       messageBuffers = personalizedBuffers_; 
450       bucketIndex = destinationCoordinates.rowIndex; 
451       break;
452     default: 
453       CkError("Incorrect MeshStreamer message type\n");
454       break;
455     }
456
457     storeMessage(messageBuffers, bucketIndex, destinationPe, 
458                  destinationCoordinates, dataItem);
459     
460   }
461
462   delete msg;
463
464 }
465
466 /*
467 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
468
469   // sort data items into messages for each core on this node
470
471   LocalMessage *localMsgs[numPesPerNode_];
472   int dataSize = bucketSize_ * dataItemSize_;
473
474   for (int i = 0; i < numPesPerNode_; i++) {
475     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
476   }
477
478   int destinationPe;
479   for (int i = 0; i < msg->numDataItems; i++) {
480
481     destinationPe = msg->destinationPes[i]; 
482     void *dataItem = msg->getDataItem(i);   
483     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
484
485   }
486
487   for (int i = 0; i < numPesPerNode_; i++) {
488     if (localMsgs[i]->numDataItems > 0) {
489       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
490     }
491     else {
492       delete localMsgs[i];
493     }
494   }
495
496   delete msg; 
497
498 }
499 */
500
501 template <class dtype>
502 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
503                                       const int numBuffers, const int myIndex, 
504                                       const int dimensionFactor) {
505
506   int flushIndex, maxSize, destinationIndex;
507   MeshStreamerMessage<dtype> *destinationBucket; 
508   maxSize = 0;
509   for (int i = 0; i < numBuffers; i++) {
510     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
511       maxSize = messageBuffers[i]->numDataItems;
512       flushIndex = i;
513     } 
514   }
515   if (maxSize > 0) {
516     destinationBucket = messageBuffers[flushIndex];
517     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
518
519     if (destinationBucket->numDataItems < bucketSize_) {
520       // not sending the full buffer, shrink the message size
521       envelope *env = UsrToEnv(destinationBucket);
522       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
523     }
524     numDataItemsBuffered_ -= destinationBucket->numDataItems;
525
526     if (messageBuffers == personalizedBuffers_) {
527       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
528     }
529     else {
530       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
531     }
532     messageBuffers[flushIndex] = NULL;
533   }
534 }
535
536 template <class dtype>
537 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
538 {
539
540     for (int i = 0; i < numBuffers; i++) {
541        if(messageBuffers[i] == NULL)
542            continue;
543        //flush all messages in i bucket
544        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
545        if (messageBuffers == personalizedBuffers_) {
546          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
547          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
548        }
549        else {
550          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
551            MeshStreamerMessage<dtype> *directMsg = 
552              new (0, 1) MeshStreamerMessage<dtype>();
553 #ifdef DEBUG_STREAMER
554            CkAssert(directMsg != NULL);
555 #endif
556            int destinationPe = messageBuffers[i]->destinationPes[j]; 
557            dtype dataItem = messageBuffers[i]->getDataItem(j);   
558            directMsg->addDataItem(dataItem);
559            clientProxy_[destinationPe].receiveCombinedData(directMsg);
560          }
561          delete messageBuffers[i];
562        }
563        messageBuffers[i] = NULL;
564     }
565
566 }
567
568 template <class dtype>
569 void MeshStreamer<dtype>::flushDirect(){
570
571     if (!isPeriodicFlushEnabled_ || 
572         1000 * (CkWallTimer() - timeOfLastSend_) >= progressPeriodInMs_) {
573       flushBuckets(planeBuffers_, numPlanes_);
574       flushBuckets(columnBuffers_, numColumns_);
575       flushBuckets(personalizedBuffers_, numRows_);
576     }
577
578     if (isPeriodicFlushEnabled_) {
579       timeOfLastSend_ = CkWallTimer();
580     }
581
582 #ifdef DEBUG_STREAMER
583     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
584     CkAssert(numDataItemsBuffered_ == 0); 
585 #endif
586
587 }
588
589 template <class dtype>
590 void periodicProgressFunction(void *MeshStreamerObj, double time) {
591
592   MeshStreamer<dtype> *properObj = 
593     static_cast<MeshStreamer<dtype>*>(MeshStreamerObj); 
594
595   if (properObj->isPeriodicFlushEnabled()) {
596     properObj->flushDirect();
597     properObj->registerPeriodicProgressFunction();
598   }
599 }
600
601 template <class dtype>
602 void MeshStreamer<dtype>::registerPeriodicProgressFunction() {
603   CcdCallFnAfter(periodicProgressFunction<dtype>, (void *) this, progressPeriodInMs_); 
604 }
605
606
607 #define CK_TEMPLATES_ONLY
608 #include "MeshStreamer.def.h"
609 #undef CK_TEMPLATES_ONLY
610
611 #endif