70029953d2b06b02d844c66cb067c126bc27b8db
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include "MeshStreamer.decl.h"
5 // allocate more total buffer space than the maximum buffering limit but flush upon
6 // reaching totalBufferCapacity_
7 #define BUCKET_SIZE_FACTOR 4
8
9 //#define DEBUG_STREAMER
10
11 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
12
13 class MeshLocation {
14  public:
15   int rowIndex;
16   int columnIndex;
17   int planeIndex; 
18   MeshStreamerMessageType msgType;
19 };
20
21 //#define HASH_LOCATIONS
22
23 #ifdef HASH_LOCATIONS
24 #include <map>
25 #endif
26
27 /*
28 class LocalMessage : public CMessage_LocalMessage {
29 public:
30     int numDataItems; 
31     int dataItemSize; 
32     char *data;
33
34     LocalMessage(int dataItemSizeInBytes) {
35         numDataItems = 0; 
36         dataItemSize = dataItemSizeInBytes; 
37     }
38
39     int addDataItem(void *dataItem) {
40         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
41         return ++numDataItems; 
42     } 
43
44     void *getDataItem(int index) {
45         return (void *) (&data[index * dataItemSize]);  
46     }
47
48 };
49 */
50
51 template<class dtype>
52 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
53 public:
54     int numDataItems;
55     int *destinationPes;
56     dtype *data;
57
58     MeshStreamerMessage(): numDataItems(0) {}   
59
60     int addDataItem(const dtype &dataItem) {
61         data[numDataItems] = dataItem;
62         return ++numDataItems; 
63     }
64
65     void markDestination(const int index, const int destinationPe) {
66         destinationPes[index] = destinationPe;
67     }
68
69     dtype &getDataItem(const int index) {
70         return data[index];
71     }
72 };
73
74 template <class dtype>
75 class MeshStreamerClient : public Group {
76  public:
77      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
78      virtual void process(dtype &data)=0; 
79 };
80
81 template <class dtype>
82 class MeshStreamer : public Group {
83
84 private:
85     int bucketSize_; 
86     int totalBufferCapacity_;
87     int numDataItemsBuffered_;
88
89     int numNodes_; 
90     int numRows_; 
91     int numColumns_; 
92     int numPlanes_; 
93     int planeSize_;
94
95     CProxy_MeshStreamerClient<dtype> clientProxy_;
96     MeshStreamerClient<dtype> *clientObj_;
97
98     int myNodeIndex_;
99     int myPlaneIndex_;
100     int myColumnIndex_; 
101     int myRowIndex_;
102
103     CkCallback   userCallback_;
104     int yieldFlag_;
105
106     double progressPeriodInMs_; 
107     bool isPeriodicFlushEnabled_; 
108     double timeOfLastSend_; 
109
110     MeshStreamerMessage<dtype> **personalizedBuffers_; 
111     MeshStreamerMessage<dtype> **columnBuffers_; 
112     MeshStreamerMessage<dtype> **planeBuffers_;
113
114 #ifdef HASH_LOCATIONS
115     std::map<int, MeshLocation> hashedLocations; 
116 #endif
117
118     void determineLocation(const int destinationPe, 
119                            MeshLocation &destinationCoordinates);
120
121     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
122                       const int bucketIndex, const int destinationPe, 
123                       const MeshLocation &destinationCoordinates, const dtype &dataItem);
124
125     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
126                             const int numBuffers, const int myIndex, 
127                             const int dimensionFactor);
128
129 public:
130
131     MeshStreamer(int totalBufferCapacity, int numRows, 
132                  int numColumns, int numPlanes, 
133                  const CProxy_MeshStreamerClient<dtype> &clientProxy,
134                  int yieldFlag = 0, double progressPeriodInMs = -1.0);
135     ~MeshStreamer();
136
137       // entry
138     void insertData(dtype &dataItem, const int destinationPe); 
139     void doneInserting();
140     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
141     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
142
143     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
144     void flushDirect();
145
146     bool isPeriodicFlushEnabled() {
147       return isPeriodicFlushEnabled_;
148     }
149       // non entry
150     void associateCallback(CkCallback &cb, bool automaticFinish = true) { 
151       userCallback_ = cb;
152       if (automaticFinish) {
153         CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), thisProxy));
154       }
155     }
156
157     void registerPeriodicProgressFunction();
158     void finish(CkReductionMsg *msg);
159
160     /*
161      * Flushing begins on a PE only after enablePeriodicFlushing has been invoked.
162      */
163     void enablePeriodicFlushing(){
164       isPeriodicFlushEnabled_ = true; 
165       registerPeriodicProgressFunction();
166     }
167 };
168
169 template <class dtype>
170 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
171   for (int i = 0; i < msg->numDataItems; i++) {
172      dtype data = ((dtype*)(msg->data))[i];
173      process(data);
174   }
175   delete msg;
176 }
177
178 template <class dtype>
179 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
180                            int numColumns, int numPlanes, 
181                            const CProxy_MeshStreamerClient<dtype> &clientProxy,
182                            int yieldFlag, double progressPeriodInMs): yieldFlag_(yieldFlag) {
183   // limit total number of messages in system to totalBufferCapacity
184   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
185   //   advantage of nonuniform filling of buckets
186   // the buffers for your own column and plane are never used
187   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
188   totalBufferCapacity_ = totalBufferCapacity;
189   numDataItemsBuffered_ = 0; 
190   numRows_ = numRows; 
191   numColumns_ = numColumns;
192   numPlanes_ = numPlanes; 
193   numNodes_ = CkNumPes(); 
194   clientProxy_ = clientProxy; 
195   clientObj_ = ((MeshStreamerClient<dtype> *)CkLocalBranch(clientProxy_));
196   progressPeriodInMs_ = progressPeriodInMs; 
197
198   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
199   for (int i = 0; i < numRows; i++) {
200     personalizedBuffers_[i] = NULL; 
201   }
202
203   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
204   for (int i = 0; i < numColumns; i++) {
205     columnBuffers_[i] = NULL; 
206   }
207
208   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
209   for (int i = 0; i < numPlanes; i++) {
210     planeBuffers_[i] = NULL; 
211   }
212
213   // determine plane, column, and row location of this node
214   myNodeIndex_ = CkMyPe();
215   planeSize_ = numRows_ * numColumns_; 
216   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
217   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
218   myRowIndex_ = indexWithinPlane / numColumns_;
219   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
220
221   isPeriodicFlushEnabled_ = false; 
222 }
223
224 template <class dtype>
225 MeshStreamer<dtype>::~MeshStreamer() {
226
227   for (int i = 0; i < numRows_; i++)
228       delete personalizedBuffers_[i]; 
229
230   for (int i = 0; i < numColumns_; i++)
231       delete columnBuffers_[i]; 
232
233   for (int i = 0; i < numPlanes_; i++)
234       delete planeBuffers_[i]; 
235
236   delete[] personalizedBuffers_;
237   delete[] columnBuffers_;
238   delete[] planeBuffers_; 
239
240 }
241
242 template <class dtype>
243 void MeshStreamer<dtype>::determineLocation(const int destinationPe, 
244                                             MeshLocation &destinationCoordinates) { 
245
246   int nodeIndex, indexWithinPlane; 
247
248 #ifdef HASH_LOCATIONS
249   std::map<int, MeshLocation>::iterator it;
250   if ((it = hashedLocations.find(destinationPe)) != hashedLocations.end()) {
251     destinationCoordinates = it->second;
252     return;
253   }
254 #endif
255
256   nodeIndex = destinationPe;
257   destinationCoordinates.planeIndex = nodeIndex / planeSize_; 
258   if (destinationCoordinates.planeIndex != myPlaneIndex_) {
259     destinationCoordinates.msgType = PlaneMessage;     
260   }
261   else {
262     indexWithinPlane = 
263       nodeIndex - destinationCoordinates.planeIndex * planeSize_;
264     destinationCoordinates.rowIndex = indexWithinPlane / numColumns_;
265     destinationCoordinates.columnIndex = 
266       indexWithinPlane - destinationCoordinates.rowIndex * numColumns_; 
267     if (destinationCoordinates.columnIndex != myColumnIndex_) {
268       destinationCoordinates.msgType = ColumnMessage; 
269     }
270     else {
271       destinationCoordinates.msgType = PersonalizedMessage;
272     }
273   }
274
275 #ifdef HASH_LOCATIONS
276   hashedLocations[destinationPe] = destinationCoordinates;
277 #endif
278
279 }
280
281 template <class dtype>
282 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
283                                        const int bucketIndex, const int destinationPe, 
284                                        const MeshLocation& destinationCoordinates,
285                                        const dtype &dataItem) {
286
287   // allocate new message if necessary
288   if (messageBuffers[bucketIndex] == NULL) {
289     if (destinationCoordinates.msgType == PersonalizedMessage) {
290       messageBuffers[bucketIndex] = 
291         new (0, bucketSize_) MeshStreamerMessage<dtype>();
292     }
293     else {
294       messageBuffers[bucketIndex] = 
295         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
296     }
297 #ifdef DEBUG_STREAMER
298     CkAssert(messageBuffers[bucketIndex] != NULL);
299 #endif
300   }
301   
302   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
303   
304   int numBuffered = destinationBucket->addDataItem(dataItem); 
305   if (destinationCoordinates.msgType != PersonalizedMessage) {
306     destinationBucket->markDestination(numBuffered-1, destinationPe);
307   }
308   numDataItemsBuffered_++;
309   // copy data into message and send if buffer is full
310   if (numBuffered == bucketSize_) {
311     int destinationIndex;
312     CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
313     switch (destinationCoordinates.msgType) {
314
315     case PlaneMessage:
316       destinationIndex = myNodeIndex_ + 
317         (destinationCoordinates.planeIndex - myPlaneIndex_) * planeSize_;  
318       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
319       break;
320     case ColumnMessage:
321       destinationIndex = myNodeIndex_ + 
322         (destinationCoordinates.columnIndex - myColumnIndex_);
323       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
324       break;
325     case PersonalizedMessage:
326       destinationIndex = myNodeIndex_ + 
327         (destinationCoordinates.rowIndex - myRowIndex_) * numColumns_;
328       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
329       //      thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
330       break;
331     default: 
332       CkError("Incorrect MeshStreamer message type\n");
333       break;
334     }
335     messageBuffers[bucketIndex] = NULL;
336     numDataItemsBuffered_ -= numBuffered; 
337
338     if (isPeriodicFlushEnabled_) {
339       timeOfLastSend_ = CkWallTimer();
340     }
341
342   }
343
344   if (numDataItemsBuffered_ == totalBufferCapacity_) {
345
346     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
347     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
348     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
349
350     if (isPeriodicFlushEnabled_) {
351       timeOfLastSend_ = CkWallTimer();
352     }
353
354   }
355
356 }
357
358 template <class dtype>
359 void MeshStreamer<dtype>::insertData(dtype &dataItem, const int destinationPe) {
360   static int count = 0;
361
362   if (destinationPe == CkMyPe()) {
363     clientObj_->process(dataItem);
364     return;
365   }
366
367   int indexWithinPlane; 
368   MeshLocation destinationCoordinates;
369
370   determineLocation(destinationPe, destinationCoordinates);
371
372   // determine which array of buffers is appropriate for this message
373   MeshStreamerMessage<dtype> **messageBuffers;
374   int bucketIndex; 
375
376   switch (destinationCoordinates.msgType) {
377   case PlaneMessage:
378     messageBuffers = planeBuffers_; 
379     bucketIndex = destinationCoordinates.planeIndex; 
380     break;
381   case ColumnMessage:
382     messageBuffers = columnBuffers_; 
383     bucketIndex = destinationCoordinates.columnIndex; 
384     break;
385   case PersonalizedMessage:
386     messageBuffers = personalizedBuffers_; 
387     bucketIndex = destinationCoordinates.rowIndex; 
388     break;
389   default: 
390     CkError("Unrecognized MeshStreamer message type\n");
391     break;
392   }
393
394   storeMessage(messageBuffers, bucketIndex, destinationPe, destinationCoordinates, 
395                dataItem);
396
397     // release control to scheduler if requested by the user, 
398     //   assume caller is threaded entry
399   if (yieldFlag_ && ++count % 1024 == 0) CthYield();
400 }
401
402 template <class dtype>
403 void MeshStreamer<dtype>::doneInserting() {
404   contribute(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), thisProxy));
405 }
406
407 template <class dtype>
408 void MeshStreamer<dtype>::finish(CkReductionMsg *msg) {
409
410   isPeriodicFlushEnabled_ = false; 
411   flushDirect();
412
413   if (!userCallback_.isInvalid()) {
414     CkStartQD(userCallback_);
415     userCallback_ = CkCallback();      // nullify the current callback
416   }
417
418   //  delete msg; 
419 }
420
421
422 template <class dtype>
423 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
424
425   int destinationPe; 
426   MeshStreamerMessageType msgType;   
427   MeshLocation destinationCoordinates;
428
429   for (int i = 0; i < msg->numDataItems; i++) {
430     destinationPe = msg->destinationPes[i];
431     dtype &dataItem = msg->getDataItem(i);
432     determineLocation(destinationPe, destinationCoordinates);
433 #ifdef DEBUG_STREAMER
434     CkAssert(destinationCoordinates.planeIndex == myPlaneIndex_);
435
436     if (destinationCoordinates.msgType == PersonalizedMessage) {
437       CkAssert(destinationCoordinates.columnIndex == myColumnIndex_);
438     }
439 #endif    
440
441     MeshStreamerMessage<dtype> **messageBuffers;
442     int bucketIndex; 
443
444     switch (destinationCoordinates.msgType) {
445     case ColumnMessage:
446       messageBuffers = columnBuffers_; 
447       bucketIndex = destinationCoordinates.columnIndex; 
448       break;
449     case PersonalizedMessage:
450       messageBuffers = personalizedBuffers_; 
451       bucketIndex = destinationCoordinates.rowIndex; 
452       break;
453     default: 
454       CkError("Incorrect MeshStreamer message type\n");
455       break;
456     }
457
458     storeMessage(messageBuffers, bucketIndex, destinationPe, 
459                  destinationCoordinates, dataItem);
460     
461   }
462
463   delete msg;
464
465 }
466
467 /*
468 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
469
470   // sort data items into messages for each core on this node
471
472   LocalMessage *localMsgs[numPesPerNode_];
473   int dataSize = bucketSize_ * dataItemSize_;
474
475   for (int i = 0; i < numPesPerNode_; i++) {
476     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
477   }
478
479   int destinationPe;
480   for (int i = 0; i < msg->numDataItems; i++) {
481
482     destinationPe = msg->destinationPes[i]; 
483     void *dataItem = msg->getDataItem(i);   
484     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
485
486   }
487
488   for (int i = 0; i < numPesPerNode_; i++) {
489     if (localMsgs[i]->numDataItems > 0) {
490       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
491     }
492     else {
493       delete localMsgs[i];
494     }
495   }
496
497   delete msg; 
498
499 }
500 */
501
502 template <class dtype>
503 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
504                                       const int numBuffers, const int myIndex, 
505                                       const int dimensionFactor) {
506
507   int flushIndex, maxSize, destinationIndex;
508   MeshStreamerMessage<dtype> *destinationBucket; 
509   maxSize = 0;
510   for (int i = 0; i < numBuffers; i++) {
511     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
512       maxSize = messageBuffers[i]->numDataItems;
513       flushIndex = i;
514     } 
515   }
516   if (maxSize > 0) {
517     destinationBucket = messageBuffers[flushIndex];
518     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
519
520     if (destinationBucket->numDataItems < bucketSize_) {
521       // not sending the full buffer, shrink the message size
522       envelope *env = UsrToEnv(destinationBucket);
523       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
524     }
525     numDataItemsBuffered_ -= destinationBucket->numDataItems;
526
527     if (messageBuffers == personalizedBuffers_) {
528       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
529     }
530     else {
531       CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
532       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
533     }
534     messageBuffers[flushIndex] = NULL;
535   }
536 }
537
538 template <class dtype>
539 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
540 {
541
542     for (int i = 0; i < numBuffers; i++) {
543        if(messageBuffers[i] == NULL)
544            continue;
545        //flush all messages in i bucket
546        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
547        if (messageBuffers == personalizedBuffers_) {
548          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
549          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
550        }
551        else {
552          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
553            MeshStreamerMessage<dtype> *directMsg = 
554              new (0, 1) MeshStreamerMessage<dtype>();
555 #ifdef DEBUG_STREAMER
556            CkAssert(directMsg != NULL);
557 #endif
558            int destinationPe = messageBuffers[i]->destinationPes[j]; 
559            dtype dataItem = messageBuffers[i]->getDataItem(j);   
560            directMsg->addDataItem(dataItem);
561            clientProxy_[destinationPe].receiveCombinedData(directMsg);
562          }
563          delete messageBuffers[i];
564        }
565        messageBuffers[i] = NULL;
566     }
567
568 }
569
570 template <class dtype>
571 void MeshStreamer<dtype>::flushDirect(){
572
573     if (!isPeriodicFlushEnabled_ || 
574         1000 * (CkWallTimer() - timeOfLastSend_) >= progressPeriodInMs_) {
575       flushBuckets(planeBuffers_, numPlanes_);
576       flushBuckets(columnBuffers_, numColumns_);
577       flushBuckets(personalizedBuffers_, numRows_);
578     }
579
580     if (isPeriodicFlushEnabled_) {
581       timeOfLastSend_ = CkWallTimer();
582     }
583
584 #ifdef DEBUG_STREAMER
585     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
586     CkAssert(numDataItemsBuffered_ == 0); 
587 #endif
588
589 }
590
591 template <class dtype>
592 void periodicProgressFunction(void *MeshStreamerObj, double time) {
593
594   MeshStreamer<dtype> *properObj = 
595     static_cast<MeshStreamer<dtype>*>(MeshStreamerObj); 
596
597   if (properObj->isPeriodicFlushEnabled()) {
598     properObj->flushDirect();
599     properObj->registerPeriodicProgressFunction();
600   }
601 }
602
603 template <class dtype>
604 void MeshStreamer<dtype>::registerPeriodicProgressFunction() {
605   CcdCallFnAfter(periodicProgressFunction<dtype>, (void *) this, progressPeriodInMs_); 
606 }
607
608
609 #define CK_TEMPLATES_ONLY
610 #include "MeshStreamer.def.h"
611 #undef CK_TEMPLATES_ONLY
612
613 #endif