MeshStreamer: added support for incomplete meshes.
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include <algorithm>
5 #include "MeshStreamer.decl.h"
6 // allocate more total buffer space than the maximum buffering limit but flush upon
7 // reaching totalBufferCapacity_
8 #define BUCKET_SIZE_FACTOR 4
9
10 // #define DEBUG_STREAMER
11 // #define CACHE_LOCATIONS
12 // #define SUPPORT_INCOMPLETE_MESH
13
14 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
15
16 class MeshLocation {
17  public:
18   int rowIndex;
19   int columnIndex;
20   int planeIndex; 
21   MeshStreamerMessageType msgType;
22 };
23
24
25 /*
26 class LocalMessage : public CMessage_LocalMessage {
27 public:
28     int numDataItems; 
29     int dataItemSize; 
30     char *data;
31
32     LocalMessage(int dataItemSizeInBytes) {
33         numDataItems = 0; 
34         dataItemSize = dataItemSizeInBytes; 
35     }
36
37     int addDataItem(void *dataItem) {
38         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
39         return ++numDataItems; 
40     } 
41
42     void *getDataItem(int index) {
43         return (void *) (&data[index * dataItemSize]);  
44     }
45
46 };
47 */
48
49 template<class dtype>
50 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
51 public:
52     int numDataItems;
53     int *destinationPes;
54     dtype *data;
55
56     MeshStreamerMessage(): numDataItems(0) {}   
57
58     int addDataItem(const dtype &dataItem) {
59         data[numDataItems] = dataItem;
60         return ++numDataItems; 
61     }
62
63     void markDestination(const int index, const int destinationPe) {
64         destinationPes[index] = destinationPe;
65     }
66
67     dtype &getDataItem(const int index) {
68         return data[index];
69     }
70 };
71
72 template <class dtype>
73 class MeshStreamerClient : public CBase_MeshStreamerClient<dtype> {
74  public:
75      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
76      virtual void process(dtype &data)=0; 
77 };
78
79 template <class dtype>
80 class MeshStreamer : public CBase_MeshStreamer<dtype> {
81
82 private:
83     int bucketSize_; 
84     int totalBufferCapacity_;
85     int numDataItemsBuffered_;
86
87     int numNodes_; 
88     int numRows_; 
89     int numColumns_; 
90     int numPlanes_; 
91     int planeSize_;
92
93     CProxy_MeshStreamerClient<dtype> clientProxy_;
94     MeshStreamerClient<dtype> *clientObj_;
95
96     int myNodeIndex_;
97     int myPlaneIndex_;
98     int myColumnIndex_; 
99     int myRowIndex_;
100
101     CkCallback   userCallback_;
102     int yieldFlag_;
103
104     double progressPeriodInMs_; 
105     bool isPeriodicFlushEnabled_; 
106     double timeOfLastSend_; 
107
108     MeshStreamerMessage<dtype> **personalizedBuffers_; 
109     MeshStreamerMessage<dtype> **columnBuffers_; 
110     MeshStreamerMessage<dtype> **planeBuffers_;
111
112 #ifdef CACHE_LOCATIONS
113     MeshLocation *cachedLocations;
114     bool *isCached; 
115 #endif
116
117 #ifdef SUPPORT_INCOMPLETE_MESH
118     int numNodesInLastPlane_;
119     int numFullRowsInLastPlane_;
120     int numColumnsInLastRow_;
121 #endif
122
123     void determineLocation(const int destinationPe, 
124                            MeshLocation &destinationCoordinates);
125
126     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
127                       const int bucketIndex, const int destinationPe, 
128                       const MeshLocation &destinationCoordinates, const dtype &dataItem);
129
130     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
131                             const int numBuffers, const int myIndex, 
132                             const int dimensionFactor);
133
134 public:
135
136     MeshStreamer(int totalBufferCapacity, int numRows, 
137                  int numColumns, int numPlanes, 
138                  const CProxy_MeshStreamerClient<dtype> &clientProxy,
139                  int yieldFlag = 0, double progressPeriodInMs = -1.0);
140     ~MeshStreamer();
141
142       // entry
143     void insertData(dtype &dataItem, const int destinationPe); 
144     void doneInserting();
145     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
146     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
147
148     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
149     void flushDirect();
150
151     bool isPeriodicFlushEnabled() {
152       return isPeriodicFlushEnabled_;
153     }
154       // non entry
155     void associateCallback(CkCallback &cb, bool automaticFinish = true) { 
156       userCallback_ = cb;
157       if (automaticFinish) {
158         CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy));
159       }
160     }
161
162     void registerPeriodicProgressFunction();
163     void finish(CkReductionMsg *msg);
164
165     /*
166      * Flushing begins on a PE only after enablePeriodicFlushing has been invoked.
167      */
168     void enablePeriodicFlushing(){
169       isPeriodicFlushEnabled_ = true; 
170       registerPeriodicProgressFunction();
171     }
172 };
173
174 template <class dtype>
175 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
176   for (int i = 0; i < msg->numDataItems; i++) {
177      dtype data = ((dtype*)(msg->data))[i];
178      process(data);
179   }
180   delete msg;
181 }
182
183 template <class dtype>
184 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
185                            int numColumns, int numPlanes, 
186                            const CProxy_MeshStreamerClient<dtype> &clientProxy,
187                            int yieldFlag, double progressPeriodInMs): yieldFlag_(yieldFlag) {
188   // limit total number of messages in system to totalBufferCapacity
189   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
190   //   advantage of nonuniform filling of buckets
191   // the buffers for your own column and plane are never used
192   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
193   totalBufferCapacity_ = totalBufferCapacity;
194   numDataItemsBuffered_ = 0; 
195   numRows_ = numRows; 
196   numColumns_ = numColumns;
197   numPlanes_ = numPlanes; 
198   numNodes_ = CkNumPes(); 
199   clientProxy_ = clientProxy; 
200   clientObj_ = ((MeshStreamerClient<dtype> *)CkLocalBranch(clientProxy_));
201   progressPeriodInMs_ = progressPeriodInMs; 
202
203   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
204   for (int i = 0; i < numRows; i++) {
205     personalizedBuffers_[i] = NULL; 
206   }
207
208   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
209   for (int i = 0; i < numColumns; i++) {
210     columnBuffers_[i] = NULL; 
211   }
212
213   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
214   for (int i = 0; i < numPlanes; i++) {
215     planeBuffers_[i] = NULL; 
216   }
217
218   // determine plane, column, and row location of this node
219   myNodeIndex_ = CkMyPe();
220   planeSize_ = numRows_ * numColumns_; 
221   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
222   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
223   myRowIndex_ = indexWithinPlane / numColumns_;
224   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
225
226   isPeriodicFlushEnabled_ = false; 
227
228 #ifdef CACHE_LOCATIONS
229   cachedLocations = new MeshLocation[numNodes_];
230   isCached = new bool[numNodes_];
231   std::fill(isCached, isCached + numNodes_, false);
232 #endif
233
234 #ifdef SUPPORT_INCOMPLETE_MESH
235   numNodesInLastPlane_ = numNodes_ % planeSize_; 
236   numFullRowsInLastPlane_ = numNodesInLastPlane_ / numColumns_;
237   numColumnsInLastRow_ = numNodesInLastPlane_ - numFullRowsInLastPlane_ * numColumns_;  
238 #endif
239 }
240
241 template <class dtype>
242 MeshStreamer<dtype>::~MeshStreamer() {
243
244   for (int i = 0; i < numRows_; i++)
245       delete personalizedBuffers_[i]; 
246
247   for (int i = 0; i < numColumns_; i++)
248       delete columnBuffers_[i]; 
249
250   for (int i = 0; i < numPlanes_; i++)
251       delete planeBuffers_[i]; 
252
253   delete[] personalizedBuffers_;
254   delete[] columnBuffers_;
255   delete[] planeBuffers_; 
256
257 }
258
259 template <class dtype>
260 void MeshStreamer<dtype>::determineLocation(const int destinationPe, 
261                                             MeshLocation &destinationCoordinates) { 
262
263   int nodeIndex, indexWithinPlane; 
264
265 #ifdef CACHE_LOCATIONS
266   if (isCached[destinationPe] == true) {
267     destinationCoordinates = cachedLocations[destinationPe]; 
268     return;
269   }
270 #endif
271
272   nodeIndex = destinationPe;
273   destinationCoordinates.planeIndex = nodeIndex / planeSize_; 
274   if (destinationCoordinates.planeIndex != myPlaneIndex_) {
275     destinationCoordinates.msgType = PlaneMessage;     
276   }
277   else {
278     indexWithinPlane = 
279       nodeIndex - destinationCoordinates.planeIndex * planeSize_;
280     destinationCoordinates.rowIndex = indexWithinPlane / numColumns_;
281     destinationCoordinates.columnIndex = 
282       indexWithinPlane - destinationCoordinates.rowIndex * numColumns_; 
283     if (destinationCoordinates.columnIndex != myColumnIndex_) {
284       destinationCoordinates.msgType = ColumnMessage; 
285     }
286     else {
287       destinationCoordinates.msgType = PersonalizedMessage;
288     }
289   }
290
291 #ifdef CACHE_LOCATIONS
292   cachedLocations[destinationPe] = destinationCoordinates;
293 #endif
294
295 }
296
297 template <class dtype>
298 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
299                                        const int bucketIndex, const int destinationPe, 
300                                        const MeshLocation& destinationCoordinates,
301                                        const dtype &dataItem) {
302
303   // allocate new message if necessary
304   if (messageBuffers[bucketIndex] == NULL) {
305     if (destinationCoordinates.msgType == PersonalizedMessage) {
306       messageBuffers[bucketIndex] = 
307         new (0, bucketSize_) MeshStreamerMessage<dtype>();
308     }
309     else {
310       messageBuffers[bucketIndex] = 
311         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
312     }
313 #ifdef DEBUG_STREAMER
314     CkAssert(messageBuffers[bucketIndex] != NULL);
315 #endif
316   }
317   
318   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
319   
320   int numBuffered = destinationBucket->addDataItem(dataItem); 
321   if (destinationCoordinates.msgType != PersonalizedMessage) {
322     destinationBucket->markDestination(numBuffered-1, destinationPe);
323   }
324   numDataItemsBuffered_++;
325   // copy data into message and send if buffer is full
326   if (numBuffered == bucketSize_) {
327     int destinationIndex;
328     switch (destinationCoordinates.msgType) {
329
330     case PlaneMessage:
331       destinationIndex = myNodeIndex_ + 
332         (destinationCoordinates.planeIndex - myPlaneIndex_) * planeSize_;  
333 #ifdef SUPPORT_INCOMPLETE_MESH
334       if (destinationIndex >= numNodes_) {
335         int numValidRows = numFullRowsInLastPlane_; 
336         if (numColumnsInLastRow_ > myColumnIndex_) {
337           numValidRows++; 
338         }
339         destinationIndex = destinationCoordinates.planeIndex * planeSize_ + 
340           myColumnIndex_ + (myRowIndex_ % numValidRows) * numColumns_; 
341       }
342 #endif      
343       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
344       break;
345     case ColumnMessage:
346       destinationIndex = myNodeIndex_ + 
347         (destinationCoordinates.columnIndex - myColumnIndex_);
348 #ifdef SUPPORT_INCOMPLETE_MESH
349       if (destinationIndex >= numNodes_) {
350         destinationIndex = destinationCoordinates.planeIndex * planeSize_ + 
351           destinationCoordinates.columnIndex + 
352           (myColumnIndex_ % numFullRowsInLastPlane_) * numColumns_; 
353       }
354 #endif      
355       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
356       break;
357     case PersonalizedMessage:
358       destinationIndex = myNodeIndex_ + 
359         (destinationCoordinates.rowIndex - myRowIndex_) * numColumns_;
360       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
361       //      this->thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
362       break;
363     default: 
364       CkError("Incorrect MeshStreamer message type\n");
365       break;
366     }
367     messageBuffers[bucketIndex] = NULL;
368     numDataItemsBuffered_ -= numBuffered; 
369
370     if (isPeriodicFlushEnabled_) {
371       timeOfLastSend_ = CkWallTimer();
372     }
373
374   }
375
376   if (numDataItemsBuffered_ == totalBufferCapacity_) {
377
378     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
379     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
380     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
381
382     if (isPeriodicFlushEnabled_) {
383       timeOfLastSend_ = CkWallTimer();
384     }
385
386   }
387
388 }
389
390 template <class dtype>
391 void MeshStreamer<dtype>::insertData(dtype &dataItem, const int destinationPe) {
392   static int count = 0;
393
394   if (destinationPe == CkMyPe()) {
395     clientObj_->process(dataItem);
396     return;
397   }
398
399   MeshLocation destinationCoordinates;
400
401   determineLocation(destinationPe, destinationCoordinates);
402
403   // determine which array of buffers is appropriate for this message
404   MeshStreamerMessage<dtype> **messageBuffers;
405   int bucketIndex; 
406
407   switch (destinationCoordinates.msgType) {
408   case PlaneMessage:
409     messageBuffers = planeBuffers_; 
410     bucketIndex = destinationCoordinates.planeIndex; 
411     break;
412   case ColumnMessage:
413     messageBuffers = columnBuffers_; 
414     bucketIndex = destinationCoordinates.columnIndex; 
415     break;
416   case PersonalizedMessage:
417     messageBuffers = personalizedBuffers_; 
418     bucketIndex = destinationCoordinates.rowIndex; 
419     break;
420   default: 
421     CkError("Unrecognized MeshStreamer message type\n");
422     break;
423   }
424
425   storeMessage(messageBuffers, bucketIndex, destinationPe, destinationCoordinates, 
426                dataItem);
427
428     // release control to scheduler if requested by the user, 
429     //   assume caller is threaded entry
430   if (yieldFlag_ && ++count == 1024) {
431     count = 0; 
432     CthYield();
433   }
434 }
435
436 template <class dtype>
437 void MeshStreamer<dtype>::doneInserting() {
438   this->contribute(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy));
439 }
440
441 template <class dtype>
442 void MeshStreamer<dtype>::finish(CkReductionMsg *msg) {
443
444   isPeriodicFlushEnabled_ = false; 
445   flushDirect();
446
447   if (!userCallback_.isInvalid()) {
448     CkStartQD(userCallback_);
449     userCallback_ = CkCallback();      // nullify the current callback
450   }
451
452   //  delete msg; 
453 }
454
455
456 template <class dtype>
457 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
458
459   int destinationPe; 
460   MeshStreamerMessageType msgType;   
461   MeshLocation destinationCoordinates;
462
463   for (int i = 0; i < msg->numDataItems; i++) {
464     destinationPe = msg->destinationPes[i];
465     dtype &dataItem = msg->getDataItem(i);
466     determineLocation(destinationPe, destinationCoordinates);
467 #ifdef DEBUG_STREAMER
468     CkAssert(destinationCoordinates.planeIndex == myPlaneIndex_);
469
470     if (destinationCoordinates.msgType == PersonalizedMessage) {
471       CkAssert(destinationCoordinates.columnIndex == myColumnIndex_);
472     }
473 #endif    
474
475     MeshStreamerMessage<dtype> **messageBuffers;
476     int bucketIndex; 
477
478     switch (destinationCoordinates.msgType) {
479     case ColumnMessage:
480       messageBuffers = columnBuffers_; 
481       bucketIndex = destinationCoordinates.columnIndex; 
482       break;
483     case PersonalizedMessage:
484       messageBuffers = personalizedBuffers_; 
485       bucketIndex = destinationCoordinates.rowIndex; 
486       break;
487     default: 
488       CkError("Incorrect MeshStreamer message type\n");
489       break;
490     }
491
492     storeMessage(messageBuffers, bucketIndex, destinationPe, 
493                  destinationCoordinates, dataItem);
494     
495   }
496
497   delete msg;
498
499 }
500
501 /*
502 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
503
504   // sort data items into messages for each core on this node
505
506   LocalMessage *localMsgs[numPesPerNode_];
507   int dataSize = bucketSize_ * dataItemSize_;
508
509   for (int i = 0; i < numPesPerNode_; i++) {
510     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
511   }
512
513   int destinationPe;
514   for (int i = 0; i < msg->numDataItems; i++) {
515
516     destinationPe = msg->destinationPes[i]; 
517     void *dataItem = msg->getDataItem(i);   
518     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
519
520   }
521
522   for (int i = 0; i < numPesPerNode_; i++) {
523     if (localMsgs[i]->numDataItems > 0) {
524       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
525     }
526     else {
527       delete localMsgs[i];
528     }
529   }
530
531   delete msg; 
532
533 }
534 */
535
536 template <class dtype>
537 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
538                                       const int numBuffers, const int myIndex, 
539                                       const int dimensionFactor) {
540
541   int flushIndex, maxSize, destinationIndex;
542   MeshStreamerMessage<dtype> *destinationBucket; 
543   maxSize = 0;
544   for (int i = 0; i < numBuffers; i++) {
545     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
546       maxSize = messageBuffers[i]->numDataItems;
547       flushIndex = i;
548     } 
549   }
550   if (maxSize > 0) {
551     destinationBucket = messageBuffers[flushIndex];
552     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
553
554     if (destinationBucket->numDataItems < bucketSize_) {
555       // not sending the full buffer, shrink the message size
556       envelope *env = UsrToEnv(destinationBucket);
557       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
558     }
559     numDataItemsBuffered_ -= destinationBucket->numDataItems;
560
561     if (messageBuffers == personalizedBuffers_) {
562       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
563     }
564     else {
565       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
566     }
567     messageBuffers[flushIndex] = NULL;
568   }
569 }
570
571 template <class dtype>
572 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
573 {
574
575     for (int i = 0; i < numBuffers; i++) {
576        if(messageBuffers[i] == NULL)
577            continue;
578        //flush all messages in i bucket
579        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
580        if (messageBuffers == personalizedBuffers_) {
581          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
582          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
583        }
584        else {
585          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
586            MeshStreamerMessage<dtype> *directMsg = 
587              new (0, 1) MeshStreamerMessage<dtype>();
588 #ifdef DEBUG_STREAMER
589            CkAssert(directMsg != NULL);
590 #endif
591            int destinationPe = messageBuffers[i]->destinationPes[j]; 
592            dtype dataItem = messageBuffers[i]->getDataItem(j);   
593            directMsg->addDataItem(dataItem);
594            clientProxy_[destinationPe].receiveCombinedData(directMsg);
595          }
596          delete messageBuffers[i];
597        }
598        messageBuffers[i] = NULL;
599     }
600
601 }
602
603 template <class dtype>
604 void MeshStreamer<dtype>::flushDirect(){
605
606     if (!isPeriodicFlushEnabled_ || 
607         1000 * (CkWallTimer() - timeOfLastSend_) >= progressPeriodInMs_) {
608       flushBuckets(planeBuffers_, numPlanes_);
609       flushBuckets(columnBuffers_, numColumns_);
610       flushBuckets(personalizedBuffers_, numRows_);
611     }
612
613     if (isPeriodicFlushEnabled_) {
614       timeOfLastSend_ = CkWallTimer();
615     }
616
617 #ifdef DEBUG_STREAMER
618     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
619     CkAssert(numDataItemsBuffered_ == 0); 
620 #endif
621
622 }
623
624 template <class dtype>
625 void periodicProgressFunction(void *MeshStreamerObj, double time) {
626
627   MeshStreamer<dtype> *properObj = 
628     static_cast<MeshStreamer<dtype>*>(MeshStreamerObj); 
629
630   if (properObj->isPeriodicFlushEnabled()) {
631     properObj->flushDirect();
632     properObj->registerPeriodicProgressFunction();
633   }
634 }
635
636 template <class dtype>
637 void MeshStreamer<dtype>::registerPeriodicProgressFunction() {
638   CcdCallFnAfter(periodicProgressFunction<dtype>, (void *) this, progressPeriodInMs_); 
639 }
640
641
642 #define CK_TEMPLATES_ONLY
643 #include "MeshStreamer.def.h"
644 #undef CK_TEMPLATES_ONLY
645
646 #endif