MeshStreamer: rewrote location caching scheme using arrays instead of a map
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include <algorithm>
5 #include "MeshStreamer.decl.h"
6 // allocate more total buffer space than the maximum buffering limit but flush upon
7 // reaching totalBufferCapacity_
8 #define BUCKET_SIZE_FACTOR 4
9
10 //#define DEBUG_STREAMER
11
12 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
13
14 class MeshLocation {
15  public:
16   int rowIndex;
17   int columnIndex;
18   int planeIndex; 
19   MeshStreamerMessageType msgType;
20 };
21
22 // #define CACHE_LOCATIONS
23
24 /*
25 class LocalMessage : public CMessage_LocalMessage {
26 public:
27     int numDataItems; 
28     int dataItemSize; 
29     char *data;
30
31     LocalMessage(int dataItemSizeInBytes) {
32         numDataItems = 0; 
33         dataItemSize = dataItemSizeInBytes; 
34     }
35
36     int addDataItem(void *dataItem) {
37         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
38         return ++numDataItems; 
39     } 
40
41     void *getDataItem(int index) {
42         return (void *) (&data[index * dataItemSize]);  
43     }
44
45 };
46 */
47
48 template<class dtype>
49 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
50 public:
51     int numDataItems;
52     int *destinationPes;
53     dtype *data;
54
55     MeshStreamerMessage(): numDataItems(0) {}   
56
57     int addDataItem(const dtype &dataItem) {
58         data[numDataItems] = dataItem;
59         return ++numDataItems; 
60     }
61
62     void markDestination(const int index, const int destinationPe) {
63         destinationPes[index] = destinationPe;
64     }
65
66     dtype &getDataItem(const int index) {
67         return data[index];
68     }
69 };
70
71 template <class dtype>
72 class MeshStreamerClient : public CBase_MeshStreamerClient<dtype> {
73  public:
74      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
75      virtual void process(dtype &data)=0; 
76 };
77
78 template <class dtype>
79 class MeshStreamer : public CBase_MeshStreamer<dtype> {
80
81 private:
82     int bucketSize_; 
83     int totalBufferCapacity_;
84     int numDataItemsBuffered_;
85
86     int numNodes_; 
87     int numRows_; 
88     int numColumns_; 
89     int numPlanes_; 
90     int planeSize_;
91
92     CProxy_MeshStreamerClient<dtype> clientProxy_;
93     MeshStreamerClient<dtype> *clientObj_;
94
95     int myNodeIndex_;
96     int myPlaneIndex_;
97     int myColumnIndex_; 
98     int myRowIndex_;
99
100     CkCallback   userCallback_;
101     int yieldFlag_;
102
103     double progressPeriodInMs_; 
104     bool isPeriodicFlushEnabled_; 
105     double timeOfLastSend_; 
106
107     MeshStreamerMessage<dtype> **personalizedBuffers_; 
108     MeshStreamerMessage<dtype> **columnBuffers_; 
109     MeshStreamerMessage<dtype> **planeBuffers_;
110
111 #ifdef CACHE_LOCATIONS
112     MeshLocation *cachedLocations;
113     bool *isCached; 
114 #endif
115
116     void determineLocation(const int destinationPe, 
117                            MeshLocation &destinationCoordinates);
118
119     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
120                       const int bucketIndex, const int destinationPe, 
121                       const MeshLocation &destinationCoordinates, const dtype &dataItem);
122
123     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
124                             const int numBuffers, const int myIndex, 
125                             const int dimensionFactor);
126
127 public:
128
129     MeshStreamer(int totalBufferCapacity, int numRows, 
130                  int numColumns, int numPlanes, 
131                  const CProxy_MeshStreamerClient<dtype> &clientProxy,
132                  int yieldFlag = 0, double progressPeriodInMs = -1.0);
133     ~MeshStreamer();
134
135       // entry
136     void insertData(dtype &dataItem, const int destinationPe); 
137     void doneInserting();
138     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
139     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
140
141     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
142     void flushDirect();
143
144     bool isPeriodicFlushEnabled() {
145       return isPeriodicFlushEnabled_;
146     }
147       // non entry
148     void associateCallback(CkCallback &cb, bool automaticFinish = true) { 
149       userCallback_ = cb;
150       if (automaticFinish) {
151         CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy));
152       }
153     }
154
155     void registerPeriodicProgressFunction();
156     void finish(CkReductionMsg *msg);
157
158     /*
159      * Flushing begins on a PE only after enablePeriodicFlushing has been invoked.
160      */
161     void enablePeriodicFlushing(){
162       isPeriodicFlushEnabled_ = true; 
163       registerPeriodicProgressFunction();
164     }
165 };
166
167 template <class dtype>
168 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
169   for (int i = 0; i < msg->numDataItems; i++) {
170      dtype data = ((dtype*)(msg->data))[i];
171      process(data);
172   }
173   delete msg;
174 }
175
176 template <class dtype>
177 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
178                            int numColumns, int numPlanes, 
179                            const CProxy_MeshStreamerClient<dtype> &clientProxy,
180                            int yieldFlag, double progressPeriodInMs): yieldFlag_(yieldFlag) {
181   // limit total number of messages in system to totalBufferCapacity
182   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
183   //   advantage of nonuniform filling of buckets
184   // the buffers for your own column and plane are never used
185   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
186   totalBufferCapacity_ = totalBufferCapacity;
187   numDataItemsBuffered_ = 0; 
188   numRows_ = numRows; 
189   numColumns_ = numColumns;
190   numPlanes_ = numPlanes; 
191   numNodes_ = CkNumPes(); 
192   clientProxy_ = clientProxy; 
193   clientObj_ = ((MeshStreamerClient<dtype> *)CkLocalBranch(clientProxy_));
194   progressPeriodInMs_ = progressPeriodInMs; 
195
196   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
197   for (int i = 0; i < numRows; i++) {
198     personalizedBuffers_[i] = NULL; 
199   }
200
201   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
202   for (int i = 0; i < numColumns; i++) {
203     columnBuffers_[i] = NULL; 
204   }
205
206   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
207   for (int i = 0; i < numPlanes; i++) {
208     planeBuffers_[i] = NULL; 
209   }
210
211   // determine plane, column, and row location of this node
212   myNodeIndex_ = CkMyPe();
213   planeSize_ = numRows_ * numColumns_; 
214   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
215   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
216   myRowIndex_ = indexWithinPlane / numColumns_;
217   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
218
219   isPeriodicFlushEnabled_ = false; 
220
221 #ifdef CACHE_LOCATIONS
222   cachedLocations = new MeshLocation[numNodes_];
223   isCached = new bool[numNodes_];
224   std::fill(isCached, isCached + numNodes_, false);
225 #endif
226
227 }
228
229 template <class dtype>
230 MeshStreamer<dtype>::~MeshStreamer() {
231
232   for (int i = 0; i < numRows_; i++)
233       delete personalizedBuffers_[i]; 
234
235   for (int i = 0; i < numColumns_; i++)
236       delete columnBuffers_[i]; 
237
238   for (int i = 0; i < numPlanes_; i++)
239       delete planeBuffers_[i]; 
240
241   delete[] personalizedBuffers_;
242   delete[] columnBuffers_;
243   delete[] planeBuffers_; 
244
245 }
246
247 template <class dtype>
248 void MeshStreamer<dtype>::determineLocation(const int destinationPe, 
249                                             MeshLocation &destinationCoordinates) { 
250
251   int nodeIndex, indexWithinPlane; 
252
253 #ifdef CACHE_LOCATIONS
254   if (isCached[destinationPe] == true) {
255     destinationCoordinates = cachedLocations[destinationPe]; 
256     return;
257   }
258 #endif
259
260   nodeIndex = destinationPe;
261   destinationCoordinates.planeIndex = nodeIndex / planeSize_; 
262   if (destinationCoordinates.planeIndex != myPlaneIndex_) {
263     destinationCoordinates.msgType = PlaneMessage;     
264   }
265   else {
266     indexWithinPlane = 
267       nodeIndex - destinationCoordinates.planeIndex * planeSize_;
268     destinationCoordinates.rowIndex = indexWithinPlane / numColumns_;
269     destinationCoordinates.columnIndex = 
270       indexWithinPlane - destinationCoordinates.rowIndex * numColumns_; 
271     if (destinationCoordinates.columnIndex != myColumnIndex_) {
272       destinationCoordinates.msgType = ColumnMessage; 
273     }
274     else {
275       destinationCoordinates.msgType = PersonalizedMessage;
276     }
277   }
278
279 #ifdef CACHE_LOCATIONS
280   cachedLocations[destinationPe] = destinationCoordinates;
281 #endif
282
283 }
284
285 template <class dtype>
286 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
287                                        const int bucketIndex, const int destinationPe, 
288                                        const MeshLocation& destinationCoordinates,
289                                        const dtype &dataItem) {
290
291   // allocate new message if necessary
292   if (messageBuffers[bucketIndex] == NULL) {
293     if (destinationCoordinates.msgType == PersonalizedMessage) {
294       messageBuffers[bucketIndex] = 
295         new (0, bucketSize_) MeshStreamerMessage<dtype>();
296     }
297     else {
298       messageBuffers[bucketIndex] = 
299         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
300     }
301 #ifdef DEBUG_STREAMER
302     CkAssert(messageBuffers[bucketIndex] != NULL);
303 #endif
304   }
305   
306   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
307   
308   int numBuffered = destinationBucket->addDataItem(dataItem); 
309   if (destinationCoordinates.msgType != PersonalizedMessage) {
310     destinationBucket->markDestination(numBuffered-1, destinationPe);
311   }
312   numDataItemsBuffered_++;
313   // copy data into message and send if buffer is full
314   if (numBuffered == bucketSize_) {
315     int destinationIndex;
316     switch (destinationCoordinates.msgType) {
317
318     case PlaneMessage:
319       destinationIndex = myNodeIndex_ + 
320         (destinationCoordinates.planeIndex - myPlaneIndex_) * planeSize_;  
321       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
322       break;
323     case ColumnMessage:
324       destinationIndex = myNodeIndex_ + 
325         (destinationCoordinates.columnIndex - myColumnIndex_);
326       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
327       break;
328     case PersonalizedMessage:
329       destinationIndex = myNodeIndex_ + 
330         (destinationCoordinates.rowIndex - myRowIndex_) * numColumns_;
331       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
332       //      this->thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
333       break;
334     default: 
335       CkError("Incorrect MeshStreamer message type\n");
336       break;
337     }
338     messageBuffers[bucketIndex] = NULL;
339     numDataItemsBuffered_ -= numBuffered; 
340
341     if (isPeriodicFlushEnabled_) {
342       timeOfLastSend_ = CkWallTimer();
343     }
344
345   }
346
347   if (numDataItemsBuffered_ == totalBufferCapacity_) {
348
349     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
350     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
351     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
352
353     if (isPeriodicFlushEnabled_) {
354       timeOfLastSend_ = CkWallTimer();
355     }
356
357   }
358
359 }
360
361 template <class dtype>
362 void MeshStreamer<dtype>::insertData(dtype &dataItem, const int destinationPe) {
363   static int count = 0;
364
365   if (destinationPe == CkMyPe()) {
366     clientObj_->process(dataItem);
367     return;
368   }
369
370   int indexWithinPlane; 
371   MeshLocation destinationCoordinates;
372
373   determineLocation(destinationPe, destinationCoordinates);
374
375   // determine which array of buffers is appropriate for this message
376   MeshStreamerMessage<dtype> **messageBuffers;
377   int bucketIndex; 
378
379   switch (destinationCoordinates.msgType) {
380   case PlaneMessage:
381     messageBuffers = planeBuffers_; 
382     bucketIndex = destinationCoordinates.planeIndex; 
383     break;
384   case ColumnMessage:
385     messageBuffers = columnBuffers_; 
386     bucketIndex = destinationCoordinates.columnIndex; 
387     break;
388   case PersonalizedMessage:
389     messageBuffers = personalizedBuffers_; 
390     bucketIndex = destinationCoordinates.rowIndex; 
391     break;
392   default: 
393     CkError("Unrecognized MeshStreamer message type\n");
394     break;
395   }
396
397   storeMessage(messageBuffers, bucketIndex, destinationPe, destinationCoordinates, 
398                dataItem);
399
400     // release control to scheduler if requested by the user, 
401     //   assume caller is threaded entry
402   if (yieldFlag_ && ++count % 1024 == 0) CthYield();
403 }
404
405 template <class dtype>
406 void MeshStreamer<dtype>::doneInserting() {
407   this->contribute(CkCallback(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy));
408 }
409
410 template <class dtype>
411 void MeshStreamer<dtype>::finish(CkReductionMsg *msg) {
412
413   isPeriodicFlushEnabled_ = false; 
414   flushDirect();
415
416   if (!userCallback_.isInvalid()) {
417     CkStartQD(userCallback_);
418     userCallback_ = CkCallback();      // nullify the current callback
419   }
420
421   //  delete msg; 
422 }
423
424
425 template <class dtype>
426 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
427
428   int destinationPe; 
429   MeshStreamerMessageType msgType;   
430   MeshLocation destinationCoordinates;
431
432   for (int i = 0; i < msg->numDataItems; i++) {
433     destinationPe = msg->destinationPes[i];
434     dtype &dataItem = msg->getDataItem(i);
435     determineLocation(destinationPe, destinationCoordinates);
436 #ifdef DEBUG_STREAMER
437     CkAssert(destinationCoordinates.planeIndex == myPlaneIndex_);
438
439     if (destinationCoordinates.msgType == PersonalizedMessage) {
440       CkAssert(destinationCoordinates.columnIndex == myColumnIndex_);
441     }
442 #endif    
443
444     MeshStreamerMessage<dtype> **messageBuffers;
445     int bucketIndex; 
446
447     switch (destinationCoordinates.msgType) {
448     case ColumnMessage:
449       messageBuffers = columnBuffers_; 
450       bucketIndex = destinationCoordinates.columnIndex; 
451       break;
452     case PersonalizedMessage:
453       messageBuffers = personalizedBuffers_; 
454       bucketIndex = destinationCoordinates.rowIndex; 
455       break;
456     default: 
457       CkError("Incorrect MeshStreamer message type\n");
458       break;
459     }
460
461     storeMessage(messageBuffers, bucketIndex, destinationPe, 
462                  destinationCoordinates, dataItem);
463     
464   }
465
466   delete msg;
467
468 }
469
470 /*
471 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
472
473   // sort data items into messages for each core on this node
474
475   LocalMessage *localMsgs[numPesPerNode_];
476   int dataSize = bucketSize_ * dataItemSize_;
477
478   for (int i = 0; i < numPesPerNode_; i++) {
479     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
480   }
481
482   int destinationPe;
483   for (int i = 0; i < msg->numDataItems; i++) {
484
485     destinationPe = msg->destinationPes[i]; 
486     void *dataItem = msg->getDataItem(i);   
487     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
488
489   }
490
491   for (int i = 0; i < numPesPerNode_; i++) {
492     if (localMsgs[i]->numDataItems > 0) {
493       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
494     }
495     else {
496       delete localMsgs[i];
497     }
498   }
499
500   delete msg; 
501
502 }
503 */
504
505 template <class dtype>
506 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
507                                       const int numBuffers, const int myIndex, 
508                                       const int dimensionFactor) {
509
510   int flushIndex, maxSize, destinationIndex;
511   MeshStreamerMessage<dtype> *destinationBucket; 
512   maxSize = 0;
513   for (int i = 0; i < numBuffers; i++) {
514     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
515       maxSize = messageBuffers[i]->numDataItems;
516       flushIndex = i;
517     } 
518   }
519   if (maxSize > 0) {
520     destinationBucket = messageBuffers[flushIndex];
521     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
522
523     if (destinationBucket->numDataItems < bucketSize_) {
524       // not sending the full buffer, shrink the message size
525       envelope *env = UsrToEnv(destinationBucket);
526       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
527     }
528     numDataItemsBuffered_ -= destinationBucket->numDataItems;
529
530     if (messageBuffers == personalizedBuffers_) {
531       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
532     }
533     else {
534       this->thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
535     }
536     messageBuffers[flushIndex] = NULL;
537   }
538 }
539
540 template <class dtype>
541 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
542 {
543
544     for (int i = 0; i < numBuffers; i++) {
545        if(messageBuffers[i] == NULL)
546            continue;
547        //flush all messages in i bucket
548        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
549        if (messageBuffers == personalizedBuffers_) {
550          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
551          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
552        }
553        else {
554          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
555            MeshStreamerMessage<dtype> *directMsg = 
556              new (0, 1) MeshStreamerMessage<dtype>();
557 #ifdef DEBUG_STREAMER
558            CkAssert(directMsg != NULL);
559 #endif
560            int destinationPe = messageBuffers[i]->destinationPes[j]; 
561            dtype dataItem = messageBuffers[i]->getDataItem(j);   
562            directMsg->addDataItem(dataItem);
563            clientProxy_[destinationPe].receiveCombinedData(directMsg);
564          }
565          delete messageBuffers[i];
566        }
567        messageBuffers[i] = NULL;
568     }
569
570 }
571
572 template <class dtype>
573 void MeshStreamer<dtype>::flushDirect(){
574
575     if (!isPeriodicFlushEnabled_ || 
576         1000 * (CkWallTimer() - timeOfLastSend_) >= progressPeriodInMs_) {
577       flushBuckets(planeBuffers_, numPlanes_);
578       flushBuckets(columnBuffers_, numColumns_);
579       flushBuckets(personalizedBuffers_, numRows_);
580     }
581
582     if (isPeriodicFlushEnabled_) {
583       timeOfLastSend_ = CkWallTimer();
584     }
585
586 #ifdef DEBUG_STREAMER
587     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
588     CkAssert(numDataItemsBuffered_ == 0); 
589 #endif
590
591 }
592
593 template <class dtype>
594 void periodicProgressFunction(void *MeshStreamerObj, double time) {
595
596   MeshStreamer<dtype> *properObj = 
597     static_cast<MeshStreamer<dtype>*>(MeshStreamerObj); 
598
599   if (properObj->isPeriodicFlushEnabled()) {
600     properObj->flushDirect();
601     properObj->registerPeriodicProgressFunction();
602   }
603 }
604
605 template <class dtype>
606 void MeshStreamer<dtype>::registerPeriodicProgressFunction() {
607   CcdCallFnAfter(periodicProgressFunction<dtype>, (void *) this, progressPeriodInMs_); 
608 }
609
610
611 #define CK_TEMPLATES_ONLY
612 #include "MeshStreamer.def.h"
613 #undef CK_TEMPLATES_ONLY
614
615 #endif