5650fd392771c3059da62f23a0c646e118daf647
[charm.git] / src / libs / ck-libs / MeshStreamer / MeshStreamer.h
1 #ifndef _MESH_STREAMER_H_
2 #define _MESH_STREAMER_H_
3
4 #include "MeshStreamer.decl.h"
5
6 // allocate more total buffer space then the maximum buffering limit but flush upon
7 // reaching totalBufferCapacity_
8 #define BUCKET_SIZE_FACTOR 4
9
10 //#define DEBUG_STREAMER 1
11
12 enum MeshStreamerMessageType {PlaneMessage, ColumnMessage, PersonalizedMessage};
13
14 /*
15 class LocalMessage : public CMessage_LocalMessage {
16 public:
17     int numDataItems; 
18     int dataItemSize; 
19     char *data;
20
21     LocalMessage(int dataItemSizeInBytes) {
22         numDataItems = 0; 
23         dataItemSize = dataItemSizeInBytes; 
24     }
25
26     int addDataItem(void *dataItem) {
27         CmiMemcpy(&data[numDataItems * dataItemSize], dataItem, dataItemSize);
28         return ++numDataItems; 
29     } 
30
31     void *getDataItem(int index) {
32         return (void *) (&data[index * dataItemSize]);  
33     }
34
35 };
36 */
37
38 template<class dtype>
39 class MeshStreamerMessage : public CMessage_MeshStreamerMessage<dtype> {
40 public:
41     int numDataItems;
42     int *destinationPes;
43     dtype *data;
44
45     MeshStreamerMessage(): numDataItems(0) {}   
46
47     int addDataItem(const dtype &dataItem) {
48         data[numDataItems] = dataItem;
49         return ++numDataItems; 
50     }
51
52     void markDestination(const int index, const int destinationPe) {
53         destinationPes[index] = destinationPe;
54     }
55
56     dtype getDataItem(const int index) {
57         return data[index];
58     }
59 };
60
61 template <class dtype>
62 class MeshStreamerClient : public Group {
63  public:
64      virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
65      virtual void process(dtype data)=0; 
66 };
67
68 template <class dtype>
69 class MeshStreamer : public Group {
70
71 private:
72     int bucketSize_; 
73     int totalBufferCapacity_;
74     int numDataItemsBuffered_;
75
76     int numNodes_; 
77     int numRows_; 
78     int numColumns_; 
79     int numPlanes_; 
80     int planeSize_;
81
82     CProxy_MeshStreamerClient<dtype> clientProxy_;
83     MeshStreamerClient<dtype> *clientObj_;
84
85     int myNodeIndex_;
86     int myPlaneIndex_;
87     int myColumnIndex_; 
88     int myRowIndex_;
89
90     CkCallback   userCallback_;
91     int yieldFlag_;
92
93     int progressPeriodInMs_; 
94     bool isPeriodicFlushEnabled_; 
95
96     MeshStreamerMessage<dtype> **personalizedBuffers_; 
97     MeshStreamerMessage<dtype> **columnBuffers_; 
98     MeshStreamerMessage<dtype> **planeBuffers_;
99
100     void determineLocation(const int destinationPe, int &row, int &column, 
101         int &plane, MeshStreamerMessageType &msgType);
102
103     void storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
104         const int bucketIndex, const int destinationPe, 
105         const int rowIndex, const int columnIndex, 
106         const int planeIndex,
107         const MeshStreamerMessageType msgType, const dtype &dataItem);
108
109     void flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
110                             const int numBuffers, const int myIndex, 
111                             const int dimensionFactor);
112
113 public:
114
115     MeshStreamer(int totalBufferCapacity, int numRows, 
116                  int numColumns, int numPlanes, 
117                  const CProxy_MeshStreamerClient<dtype> &clientProxy,
118                  int yieldFlag = 0, int progressPeriodInMs = -1);
119     ~MeshStreamer();
120
121       // entry
122     void insertData(const dtype &dataItem, const int destinationPe); 
123     void doneInserting(); 
124     void receiveAggregateData(MeshStreamerMessage<dtype> *msg);
125     // void receivePersonalizedData(MeshStreamerMessage<dtype> *msg);
126
127     void flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers);
128     void flushDirect();
129
130     bool isPeriodicFlushEnabled() {
131       return isPeriodicFlushEnabled_;
132     }
133       // non entry
134     void associateCallback(CkCallback &cb) { 
135               userCallback_ = cb;
136               CkStartQD(CkCallback(CkIndex_MeshStreamer<dtype>::flushDirect(), thisProxy));
137          }
138
139     void registerPeriodicProgressFunction();
140
141 };
142
143 template <class dtype>
144 void MeshStreamerClient<dtype>::receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
145   for (int i = 0; i < msg->numDataItems; i++) {
146      dtype data = ((dtype*)(msg->data))[i];
147      process(data);
148   }
149   delete msg;
150 }
151
152 template <class dtype>
153 MeshStreamer<dtype>::MeshStreamer(int totalBufferCapacity, int numRows, 
154                            int numColumns, int numPlanes, 
155                            const CProxy_MeshStreamerClient<dtype> &clientProxy,
156                            int yieldFlag, int progressPeriodInMs): yieldFlag_(yieldFlag) {
157   // limit total number of messages in system to totalBufferCapacity
158   //   but allocate a factor BUCKET_SIZE_FACTOR more space to take
159   //   advantage of nonuniform filling of buckets
160   // the buffers for your own column and plane are never used
161   bucketSize_ = BUCKET_SIZE_FACTOR * totalBufferCapacity / (numRows + numColumns + numPlanes - 2); 
162   totalBufferCapacity_ = totalBufferCapacity;
163   numDataItemsBuffered_ = 0; 
164   numRows_ = numRows; 
165   numColumns_ = numColumns;
166   numPlanes_ = numPlanes; 
167   numNodes_ = CkNumPes(); 
168   clientProxy_ = clientProxy; 
169   clientObj_ = ((MeshStreamerClient<dtype> *)CkLocalBranch(clientProxy_));
170   progressPeriodInMs_ = progressPeriodInMs; 
171
172   personalizedBuffers_ = new MeshStreamerMessage<dtype> *[numRows];
173   for (int i = 0; i < numRows; i++) {
174     personalizedBuffers_[i] = NULL; 
175   }
176
177   columnBuffers_ = new MeshStreamerMessage<dtype> *[numColumns];
178   for (int i = 0; i < numColumns; i++) {
179     columnBuffers_[i] = NULL; 
180   }
181
182   planeBuffers_ = new MeshStreamerMessage<dtype> *[numPlanes]; 
183   for (int i = 0; i < numPlanes; i++) {
184     planeBuffers_[i] = NULL; 
185   }
186
187   // determine plane, column, and row location of this node
188   myNodeIndex_ = CkMyPe();
189   planeSize_ = numRows_ * numColumns_; 
190   myPlaneIndex_ = myNodeIndex_ / planeSize_; 
191   int indexWithinPlane = myNodeIndex_ - myPlaneIndex_ * planeSize_;
192   myRowIndex_ = indexWithinPlane / numColumns_;
193   myColumnIndex_ = indexWithinPlane - myRowIndex_ * numColumns_; 
194
195   if (progressPeriodInMs_ > 0) {
196     isPeriodicFlushEnabled_ = true; 
197     registerPeriodicProgressFunction();
198   }
199   else {
200     isPeriodicFlushEnabled_ = false; 
201   }
202
203 }
204
205 template <class dtype>
206 MeshStreamer<dtype>::~MeshStreamer() {
207
208   for (int i = 0; i < numRows_; i++)
209       delete personalizedBuffers_[i]; 
210
211   for (int i = 0; i < numColumns_; i++)
212       delete columnBuffers_[i]; 
213
214   for (int i = 0; i < numPlanes_; i++)
215       delete planeBuffers_[i]; 
216
217   delete[] personalizedBuffers_;
218   delete[] columnBuffers_;
219   delete[] planeBuffers_; 
220
221 }
222
223 template <class dtype>
224 void MeshStreamer<dtype>::determineLocation(const int destinationPe, int &rowIndex, 
225                                      int &columnIndex, int &planeIndex, 
226                                      MeshStreamerMessageType &msgType) {
227   
228   int nodeIndex, indexWithinPlane; 
229
230   nodeIndex = destinationPe;
231   planeIndex = nodeIndex / planeSize_; 
232   if (planeIndex != myPlaneIndex_) {
233     msgType = PlaneMessage;     
234   }
235   else {
236     indexWithinPlane = nodeIndex - planeIndex * planeSize_;
237     rowIndex = indexWithinPlane / numColumns_;
238     columnIndex = indexWithinPlane - rowIndex * numColumns_; 
239     if (columnIndex != myColumnIndex_) {
240       msgType = ColumnMessage; 
241     }
242     else {
243       msgType = PersonalizedMessage;
244     }
245   }
246
247 }
248
249 template <class dtype>
250 void MeshStreamer<dtype>::storeMessage(MeshStreamerMessage<dtype> ** const messageBuffers, 
251                                 const int bucketIndex, const int destinationPe, 
252                                 const int rowIndex, const int columnIndex, 
253                                 const int planeIndex, 
254                                 const MeshStreamerMessageType msgType, 
255                                 const dtype &dataItem) {
256
257   // allocate new message if necessary
258   if (messageBuffers[bucketIndex] == NULL) {
259     if (msgType == PersonalizedMessage) {
260       messageBuffers[bucketIndex] = 
261         new (0, bucketSize_) MeshStreamerMessage<dtype>();
262     }
263     else {
264       messageBuffers[bucketIndex] = 
265         new (bucketSize_, bucketSize_) MeshStreamerMessage<dtype>();
266     }
267 #ifdef DEBUG_STREAMER
268     CkAssert(messageBuffers[bucketIndex] != NULL);
269 #endif
270   }
271   
272   MeshStreamerMessage<dtype> *destinationBucket = messageBuffers[bucketIndex];
273   
274   int numBuffered = destinationBucket->addDataItem(dataItem); 
275   if (msgType != PersonalizedMessage) {
276     destinationBucket->markDestination(numBuffered-1, destinationPe);
277   }
278   numDataItemsBuffered_++;
279   // copy data into message and send if buffer is full
280   if (numBuffered == bucketSize_) {
281     int destinationIndex;
282     CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
283     switch (msgType) {
284
285     case PlaneMessage:
286       destinationIndex = 
287         myNodeIndex_ + (planeIndex - myPlaneIndex_) * planeSize_;      
288       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
289       break;
290     case ColumnMessage:
291       destinationIndex = myNodeIndex_ + (columnIndex - myColumnIndex_);
292       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
293       break;
294     case PersonalizedMessage:
295       destinationIndex = myNodeIndex_ + (rowIndex - myRowIndex_) * numColumns_;
296       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);      
297       //      thisProxy[destinationIndex].receivePersonalizedData(destinationBucket);
298       break;
299     default: 
300       CkError("Incorrect MeshStreamer message type\n");
301       break;
302     }
303     messageBuffers[bucketIndex] = NULL;
304     numDataItemsBuffered_ -= numBuffered; 
305   }
306
307   if (numDataItemsBuffered_ == totalBufferCapacity_) {
308
309     flushLargestBucket(personalizedBuffers_, numRows_, myRowIndex_, numColumns_);
310     flushLargestBucket(columnBuffers_, numColumns_, myColumnIndex_, 1);
311     flushLargestBucket(planeBuffers_, numPlanes_, myPlaneIndex_, planeSize_);
312
313   }
314
315 }
316
317 template <class dtype>
318 void MeshStreamer<dtype>::insertData(const dtype &dataItem, const int destinationPe) {
319   static int count = 0;
320
321   if (destinationPe == CkMyPe()) {
322     clientObj_->process(dataItem);
323     return;
324   }
325
326   int planeIndex, columnIndex, rowIndex; // location of destination
327   int indexWithinPlane; 
328   MeshStreamerMessageType msgType; 
329
330   determineLocation(destinationPe, rowIndex, columnIndex, planeIndex, msgType);
331
332   // determine which array of buffers is appropriate for this message
333   MeshStreamerMessage<dtype> **messageBuffers;
334   int bucketIndex; 
335
336   switch (msgType) {
337   case PlaneMessage:
338     messageBuffers = planeBuffers_; 
339     bucketIndex = planeIndex; 
340     break;
341   case ColumnMessage:
342     messageBuffers = columnBuffers_; 
343     bucketIndex = columnIndex; 
344     break;
345   case PersonalizedMessage:
346     messageBuffers = personalizedBuffers_; 
347     bucketIndex = rowIndex; 
348     break;
349   default: 
350     CkError("Unrecognized MeshStreamer message type\n");
351     break;
352   }
353
354   storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
355                columnIndex, planeIndex, msgType, dataItem);
356
357     // release control to scheduler, assume caller is threaded entry
358   if (yieldFlag_ && ++count % 1024 == 0) CthYield();
359 }
360
361 template <class dtype>
362 void MeshStreamer<dtype>::doneInserting() {
363   // disable periodic flushing
364   isPeriodicFlushEnabled_ = false; 
365 }
366
367 template <class dtype>
368 void MeshStreamer<dtype>::receiveAggregateData(MeshStreamerMessage<dtype> *msg) {
369
370   int rowIndex, columnIndex, planeIndex, destinationPe; 
371   MeshStreamerMessageType msgType;   
372
373   for (int i = 0; i < msg->numDataItems; i++) {
374     destinationPe = msg->destinationPes[i];
375     dtype dataItem = msg->getDataItem(i);
376     determineLocation(destinationPe, rowIndex, columnIndex, 
377                       planeIndex, msgType);
378 #ifdef DEBUG_STREAMER
379     CkAssert(planeIndex == myPlaneIndex_);
380
381     if (msgType == PersonalizedMessage) {
382       CkAssert(columnIndex == myColumnIndex_);
383     }
384 #endif    
385
386     MeshStreamerMessage<dtype> **messageBuffers;
387     int bucketIndex; 
388
389     switch (msgType) {
390     case ColumnMessage:
391       messageBuffers = columnBuffers_; 
392       bucketIndex = columnIndex; 
393       break;
394     case PersonalizedMessage:
395       messageBuffers = personalizedBuffers_; 
396       bucketIndex = rowIndex; 
397       break;
398     default: 
399       CkError("Incorrect MeshStreamer message type\n");
400       break;
401     }
402
403     storeMessage(messageBuffers, bucketIndex, destinationPe, rowIndex, 
404                  columnIndex, planeIndex, msgType, dataItem);
405     
406   }
407
408   delete msg;
409
410 }
411
412 /*
413 void MeshStreamer::receivePersonalizedData(MeshStreamerMessage *msg) {
414
415   // sort data items into messages for each core on this node
416
417   LocalMessage *localMsgs[numPesPerNode_];
418   int dataSize = bucketSize_ * dataItemSize_;
419
420   for (int i = 0; i < numPesPerNode_; i++) {
421     localMsgs[i] = new (dataSize) LocalMessage(dataItemSize_);
422   }
423
424   int destinationPe;
425   for (int i = 0; i < msg->numDataItems; i++) {
426
427     destinationPe = msg->destinationPes[i]; 
428     void *dataItem = msg->getDataItem(i);   
429     localMsgs[destinationPe % numPesPerNode_]->addDataItem(dataItem);
430
431   }
432
433   for (int i = 0; i < numPesPerNode_; i++) {
434     if (localMsgs[i]->numDataItems > 0) {
435       clientProxy_[myNodeIndex_ * numPesPerNode_ + i].receiveCombinedData(localMsgs[i]);
436     }
437     else {
438       delete localMsgs[i];
439     }
440   }
441
442   delete msg; 
443
444 }
445 */
446
447 template <class dtype>
448 void MeshStreamer<dtype>::flushLargestBucket(MeshStreamerMessage<dtype> ** const messageBuffers,
449                                       const int numBuffers, const int myIndex, 
450                                       const int dimensionFactor) {
451
452   int flushIndex, maxSize, destinationIndex;
453   MeshStreamerMessage<dtype> *destinationBucket; 
454   maxSize = 0;
455   for (int i = 0; i < numBuffers; i++) {
456     if (messageBuffers[i] != NULL && messageBuffers[i]->numDataItems > maxSize) {
457       maxSize = messageBuffers[i]->numDataItems;
458       flushIndex = i;
459     } 
460   }
461   if (maxSize > 0) {
462     destinationBucket = messageBuffers[flushIndex];
463     destinationIndex = myNodeIndex_ + (flushIndex - myIndex) * dimensionFactor;
464
465     if (destinationBucket->numDataItems < bucketSize_) {
466       // not sending the full buffer, shrink the message size
467       envelope *env = UsrToEnv(destinationBucket);
468       env->setTotalsize(env->getTotalsize() - (bucketSize_ - destinationBucket->numDataItems) * sizeof(dtype));
469     }
470     numDataItemsBuffered_ -= destinationBucket->numDataItems;
471
472     if (messageBuffers == personalizedBuffers_) {
473       clientProxy_[destinationIndex].receiveCombinedData(destinationBucket);
474     }
475     else {
476       CProxy_MeshStreamer<dtype> thisProxy(thisgroup);
477       thisProxy[destinationIndex].receiveAggregateData(destinationBucket);
478     }
479     messageBuffers[flushIndex] = NULL;
480   }
481 }
482
483 template <class dtype>
484 void MeshStreamer<dtype>::flushBuckets(MeshStreamerMessage<dtype> **messageBuffers, const int numBuffers)
485 {
486
487     for (int i = 0; i < numBuffers; i++) {
488        if(messageBuffers[i] == NULL)
489            continue;
490        //flush all messages in i bucket
491        numDataItemsBuffered_ -= messageBuffers[i]->numDataItems;
492        if (messageBuffers == personalizedBuffers_) {
493          int destinationPe = myNodeIndex_ + (i - myRowIndex_) * numColumns_; 
494          clientProxy_[destinationPe].receiveCombinedData(messageBuffers[i]);
495        }
496        else {
497          for (int j = 0; j < messageBuffers[i]->numDataItems; j++) {
498            MeshStreamerMessage<dtype> *directMsg = 
499              new (0, 1) MeshStreamerMessage<dtype>();
500 #ifdef DEBUG_STREAMER
501            CkAssert(directMsg != NULL);
502 #endif
503            int destinationPe = messageBuffers[i]->destinationPes[j]; 
504            dtype dataItem = messageBuffers[i]->getDataItem(j);   
505            directMsg->addDataItem(dataItem);
506            clientProxy_[destinationPe].receiveCombinedData(directMsg);
507          }
508          delete messageBuffers[i];
509        }
510        messageBuffers[i] = NULL;
511     }
512
513 }
514
515 template <class dtype>
516 void MeshStreamer<dtype>::flushDirect(){
517     flushBuckets(planeBuffers_, numPlanes_);
518     flushBuckets(columnBuffers_, numColumns_);
519     flushBuckets(personalizedBuffers_, numRows_);
520
521 #ifdef DEBUG_STREAMER
522     //CkPrintf("[%d] numDataItemsBuffered_: %d\n", CkMyPe(), numDataItemsBuffered_);
523     CkAssert(numDataItemsBuffered_ == 0); 
524 #endif
525
526     if (!userCallback_.isInvalid()) {
527         CkStartQD(userCallback_);
528         userCallback_ = CkCallback();      // nullify the current callback
529     }
530 }
531
532 template <class dtype>
533 void periodicProgressFunction(void *MeshStreamerObj, double time) {
534
535   MeshStreamer<dtype> *properObj = static_cast<MeshStreamer<dtype>*>(MeshStreamerObj); 
536
537   properObj->flushDirect();
538
539   if (properObj->isPeriodicFlushEnabled()) {
540     properObj->registerPeriodicProgressFunction();
541   }
542 }
543
544 template <class dtype>
545 void MeshStreamer<dtype>::registerPeriodicProgressFunction() {
546   CcdCallFnAfter(periodicProgressFunction<dtype>, (void *) this, progressPeriodInMs_); 
547 }
548
549
550 #define CK_TEMPLATES_ONLY
551 #include "MeshStreamer.def.h"
552 #undef CK_TEMPLATES_ONLY
553
554 #endif