NDMeshStreamer: adding termination scheme using completion detection library
authorLukasz Wesolowski <wesolwsk@illinois.edu>
Sun, 11 Mar 2012 02:28:31 +0000 (20:28 -0600)
committerLukasz Wesolowski <wesolwsk@illinois.edu>
Sun, 11 Mar 2012 02:31:21 +0000 (20:31 -0600)
src/libs/ck-libs/NDMeshStreamer/NDMeshStreamer.ci
src/libs/ck-libs/NDMeshStreamer/NDMeshStreamer.h

index 629d992c75c2d59189b02933ead7bdab52a1d391..c1c76b0ae4d8bc9624eb39f92df52cb198e54fd7 100644 (file)
@@ -1,4 +1,5 @@
 module NDMeshStreamer {
+  extern module completion;
 
   include "DataItemTypes.h";
 
@@ -14,7 +15,7 @@ module NDMeshStreamer {
 
   template<class dtype> array [1D] MeshStreamerArrayClient {
     // entry void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
-    entry void process(dtype &data);
+    entry void receiveRedeliveredItem(dtype data);
   };
 
   template<class dtype> 
@@ -22,6 +23,8 @@ module NDMeshStreamer {
     entry void receiveAlongRoute(MeshStreamerMessage<dtype> *msg);
     entry void flushDirect();
     entry void finish(CkReductionMsg *msg);
+    entry void associateCallback(CkCallback startCb, CkCallback endCb, 
+                                CProxy_CompletionDetector detector);
   };
 
   template<class dtype>
@@ -30,7 +33,7 @@ module NDMeshStreamer {
          int totalBufferCapacity, int numDimensions, 
          int dimensionSizes[numDimensions], 
          const CProxy_MeshStreamerGroupClient<dtype> &clientProxy,
-         int yieldFlag = 0, double progressPeriodInMs = -1.0);
+         bool yieldFlag = 0, double progressPeriodInMs = -1.0);
   };
 
 
@@ -40,7 +43,7 @@ module NDMeshStreamer {
          int totalBufferCapacity, int numDimensions, 
          int dimensionSizes[numDimensions],
          const CProxy_MeshStreamerArrayClient<dtype> &clientProxy,
-         int yieldFlag = 0, double progressPeriodInMs = -1.0);
+         bool yieldFlag = 0, double progressPeriodInMs = -1.0);
 
     entry void receiveArrayData(
               MeshStreamerMessage<ArrayDataItem<dtype> > *msg); 
index af1d653e30da53738987e93cc0e625026278abe6..4be5d5a0baf1a5cc9e7b8ebd9807e3366a7f63e6 100644 (file)
@@ -4,6 +4,7 @@
 #include <algorithm>
 #include "NDMeshStreamer.decl.h"
 #include "DataItemTypes.h"
+#include "completion.h"
 
 // allocate more total buffer space than the maximum buffering limit but flush 
 //   upon reaching totalBufferCapacity_
@@ -43,21 +44,47 @@ public:
 };
 
 template <class dtype>
-class MeshStreamerGroupClient : public CBase_MeshStreamerGroupClient<dtype> {
+class MeshStreamerClient {
+ protected:
+  CompletionDetector *detectorLocalObj_;
  public:
-     virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
-     virtual void process(dtype &data)=0; 
+  // would like to make it pure virtual but charm will try to
+  // instantiate the abstract class, leading to errors
+  virtual void process(dtype &data) {};     
+  void setDetector(CompletionDetector *detectorLocalObj) {
+    detectorLocalObj_ = detectorLocalObj;
+  }
 };
 
 template <class dtype>
-class MeshStreamerArrayClient : public CBase_MeshStreamerArrayClient<dtype> {
+class MeshStreamerGroupClient : public CBase_MeshStreamerGroupClient<dtype>,
+  public MeshStreamerClient<dtype> {
  public:
-     // virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
-  // would like to make it pure virtual but charm will try to
-  // instantiate the abstract class, leading to errors
-  virtual void process(dtype &data) {} //=0; 
+
+  virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg) {
+    for (int i = 0; i < msg->numDataItems; i++) {
+      dtype &data = msg->getDataItem(i);
+      process(data);
+    }
+    MeshStreamerClient<dtype>::detectorLocalObj_->consume(msg->numDataItems);
+    delete msg;
+  }
+};
+
+template <class dtype>
+class MeshStreamerArrayClient :  public CBase_MeshStreamerArrayClient<dtype>, 
+  public MeshStreamerClient<dtype>
+   {
+ public:
+
+  // virtual void receiveCombinedData(MeshStreamerMessage<dtype> *msg);
   MeshStreamerArrayClient() {}
   MeshStreamerArrayClient(CkMigrateMessage *msg) {}
+  void receiveRedeliveredItem(dtype data) {
+    MeshStreamerClient<dtype>::detectorLocalObj_->consume();
+    process(data);
+  }
+
 };
 
 template <class dtype>
@@ -77,15 +104,16 @@ private:
     int *myLocationIndex_;
 
     CkCallback   userCallback_;
-    int yieldFlag_;
+    bool yieldFlag_;
 
     double progressPeriodInMs_; 
     bool isPeriodicFlushEnabled_; 
     double timeOfLastSend_; 
 
-
     MeshStreamerMessage<dtype> ***dataBuffers_;
 
+    CProxy_CompletionDetector detector_;
+
 #ifdef CACHE_LOCATIONS
     MeshLocation *cachedLocations_;
     bool *isCached_; 
@@ -103,10 +131,15 @@ private:
 
     virtual void localDeliver(dtype &dataItem) = 0; 
 
+    virtual int numElementsInClient() = 0;
+
+    virtual void setDetectorInClient() = 0;
+
     void flushLargestBuffer();
 
 protected:
 
+    CompletionDetector *detectorLocalObj_;
     virtual int copyDataItemIntoMessage(
                MeshStreamerMessage<dtype> *destinationBuffer, 
                void *dataItemHandle, bool copyIndirectly = false);
@@ -115,7 +148,7 @@ public:
 
     MeshStreamer(int totalBufferCapacity, int numDimensions, 
                 int *dimensionSizes,
-                 int yieldFlag = 0, double progressPeriodInMs = -1.0);
+                 bool yieldFlag = 0, double progressPeriodInMs = -1.0);
     ~MeshStreamer();
 
       // entry
@@ -137,6 +170,27 @@ public:
                             this->thisProxy));
       }
     }
+
+    void associateCallback(CkCallback startCb, CkCallback endCb, 
+                          CProxy_CompletionDetector detector) {
+      userCallback_ = endCb; 
+      static CkCallback finish(CkIndex_MeshStreamer<dtype>::finish(NULL), this->thisProxy);
+      detector_ = detector;      
+      detectorLocalObj_ = detector_.ckLocalBranch();
+      setDetectorInClient();
+      detectorLocalObj_->start_detection(numElementsInClient(), 
+                                        startCb, finish , 0);
+
+      if (progressPeriodInMs_ <= 0) {
+       CkPrintf("Using completion detection in NDMeshStreamer requires"
+                " setting a valid periodic flush period. Defaulting"
+                 " to 10 ms\n");
+       progressPeriodInMs_ = 10;
+      }
+      enablePeriodicFlushing();
+
+    }
+
     void flushAllBuffers();
     void registerPeriodicProgressFunction();
 
@@ -146,23 +200,18 @@ public:
       isPeriodicFlushEnabled_ = true; 
       registerPeriodicProgressFunction();
     }
-};
 
-template <class dtype>
-void MeshStreamerGroupClient<dtype>::receiveCombinedData(
-                                MeshStreamerMessage<dtype> *msg) {
-  for (int i = 0; i < msg->numDataItems; i++) {
-    dtype &data = msg->getDataItem(i);
-    process(data);
-  }
-  delete msg;
-}
+    void done() {
+      detectorLocalObj_->done();
+    }
+
+};
 
 template <class dtype>
 MeshStreamer<dtype>::MeshStreamer(
                     int totalBufferCapacity, int numDimensions, 
                     int *dimensionSizes, 
-                    int yieldFlag, 
+                    bool yieldFlag, 
                      double progressPeriodInMs)
  :numDimensions_(numDimensions), 
   totalBufferCapacity_(totalBufferCapacity), 
@@ -217,6 +266,7 @@ MeshStreamer<dtype>::MeshStreamer(
   }
 
   isPeriodicFlushEnabled_ = false; 
+  detectorLocalObj_ = NULL;
 
 #ifdef CACHE_LOCATIONS
   cachedLocations_ = new MeshLocation[numMembers_];
@@ -372,7 +422,6 @@ void MeshStreamer<dtype>::insertData(void *dataItemHandle, int destinationPe) {
   MeshLocation destinationLocation = determineLocation(destinationPe);
   storeMessage(destinationPe, destinationLocation, dataItemHandle, 
               copyIndirectly); 
-
   // release control to scheduler if requested by the user, 
   //   assume caller is threaded entry
   if (yieldFlag_ && ++count == 1024) {
@@ -386,6 +435,7 @@ template <class dtype>
 inline
 void MeshStreamer<dtype>::insertData(dtype &dataItem, int destinationPe) {
 
+  detectorLocalObj_->produce();
   if (destinationPe == CkMyPe()) {
     // copying here is necessary - user code should not be 
     // passed back a reference to the original item
@@ -405,12 +455,12 @@ void MeshStreamer<dtype>::doneInserting() {
 
 template <class dtype>
 void MeshStreamer<dtype>::finish(CkReductionMsg *msg) {
-
   isPeriodicFlushEnabled_ = false; 
-  flushDirect();
+  // flushDirect();
 
   if (!userCallback_.isInvalid()) {
-    CkStartQD(userCallback_);
+    this->contribute(userCallback_);
+    //CkStartQD(userCallback_);
     userCallback_ = CkCallback();      // nullify the current callback
   }
 
@@ -587,6 +637,16 @@ private:
 
   void localDeliver(dtype &dataItem) {
     clientObj_->process(dataItem);
+    MeshStreamer<dtype>::detectorLocalObj_->consume();
+  }
+
+  int numElementsInClient() {
+    // client is a group - there is one element per PE
+    return CkNumPes();
+  }
+
+  void setDetectorInClient() {
+    clientObj_->setDetector(MeshStreamer<dtype>::detectorLocalObj_);
   }
 
 public:
@@ -594,7 +654,7 @@ public:
   GroupMeshStreamer(int totalBufferCapacity, int numDimensions,
                    int *dimensionSizes, 
                    const CProxy_MeshStreamerGroupClient<dtype> &clientProxy,
-                   int yieldFlag = 0, double progressPeriodInMs = -1.0)
+                   bool yieldFlag = 0, double progressPeriodInMs = -1.0)
    :MeshStreamer<dtype>(totalBufferCapacity, numDimensions, dimensionSizes, 
                          yieldFlag, progressPeriodInMs) 
   {
@@ -603,7 +663,6 @@ public:
       ((MeshStreamerGroupClient<dtype> *)CkLocalBranch(clientProxy_));
   }
 
-
 };
 
 template <class dtype>
@@ -631,10 +690,24 @@ private:
 
     if (clientObjs_[arrayId] != NULL) {
       clientObjs_[arrayId]->process(packedDataItem.dataItem);
+      MeshStreamer<ArrayDataItem<dtype> >::detectorLocalObj_->consume();
     }
     else { 
       // array element is no longer present locally - redeliver using proxy
-      clientProxy_[arrayId].process(packedDataItem.dataItem);
+      clientProxy_[arrayId].receiveRedeliveredItem(packedDataItem.dataItem);
+    }
+  }
+
+  int numElementsInClient() {
+    return numArrayElements_;
+  }
+
+  void setDetectorInClient() {
+    for (int i = 0; i < numArrayElements_; i++) {
+      if (clientObjs_[i] != NULL) {
+       clientObjs_[i]->setDetector(
+                        MeshStreamer<ArrayDataItem<dtype> >::detectorLocalObj_);
+      }
     }
   }
 
@@ -648,7 +721,7 @@ public:
   ArrayMeshStreamer(int totalBufferCapacity, int numDimensions,
                    int *dimensionSizes, 
                    const CProxy_MeshStreamerArrayClient<dtype> &clientProxy,
-                   int yieldFlag = 0, double progressPeriodInMs = -1.0)
+                   bool yieldFlag = 0, double progressPeriodInMs = -1.0)
     :MeshStreamer<ArrayDataItem<dtype> >(totalBufferCapacity, numDimensions, 
                                        dimensionSizes, yieldFlag, 
                                        progressPeriodInMs) 
@@ -666,7 +739,8 @@ public:
 #ifdef CACHE_ARRAY_METADATA
     destinationPes_ = new int[numArrayElements_];
     isCachedArrayMetadata_ = new bool[numArrayElements_];
-    std::fill(isCachedArrayMetadata_, isCachedArrayMetadata_ + numArrayElements_, false);
+    std::fill(isCachedArrayMetadata_, 
+             isCachedArrayMetadata_ + numArrayElements_, false);
 #endif
   }
 
@@ -688,6 +762,7 @@ public:
 
   void insertData(dtype &dataItem, int arrayIndex) {
 
+    MeshStreamer<ArrayDataItem<dtype> >::detectorLocalObj_->produce();
     int destinationPe; 
 #ifdef CACHE_ARRAY_METADATA
   if (isCachedArrayMetadata_[arrayIndex]) {