Add reducer functions for user-defined functors in the entry method templates eg
authorRamprasad Venkataraman <ramv@illinois.edu>
Mon, 2 Apr 2012 18:03:20 +0000 (13:03 -0500)
committerPhil Miller <mille121@illinois.edu>
Thu, 5 Apr 2012 21:17:55 +0000 (16:17 -0500)
tests/charm++/method_templates/pgm.C
tests/charm++/method_templates/pgm.ci
tests/charm++/method_templates/utils.h

index b330ae8c300a1391c2edbd3a469bbac924fdb57d..e3b7ce3e4e5169c7813eb2d16bb8d3e85fd978ba 100644 (file)
@@ -19,6 +19,14 @@ void register_instantiations()
 };
 
 
+// Register reducer functions
+void register_reducers()
+{
+    CkReduction::reducerType countReducer = CkReduction::addReducer(count< std::less<int> >::reduce_count);
+    CkReduction::reducerType avgReducer   = CkReduction::addReducer(avg::reduce_avg);
+}
+
+
 // Test driver
 class pgm : public CBase_pgm
 {
index 0b87983ecb0f7dd97b7119a7d4c6f2945f08a8a8..96feaeb7313d9dc948fa594347bc9660d9905332 100644 (file)
@@ -1,6 +1,7 @@
 mainmodule client  {
   extern module mylib;
   initproc register_instantiations();
+  initproc register_reducers();
 
   mainchare pgm {
     entry pgm (CkArgMsg *m);
index 3bc3d6b647006d51b60b924dc72112344c5a789b..2065a215f4d1624e6d407e458c4cc94f1e1559e2 100644 (file)
@@ -10,20 +10,39 @@ class count {
     public:
         //
         count(const int _t=0): threshold(_t), num(0) {}
+
         // Operate on an input element
         inline void operator() (int* first, int *last)
         {
             for (int *ptr = first; ptr != last; ptr++)
                 if (c(*ptr, threshold)) num++;
         }
+
         // Serialize the internals
         void pup(PUP::er &p) { p | threshold; p | num; }
+
         // Spit results to ostream
         friend std::ostream& operator<< (std::ostream& out, const count& obj) {
             out << "threshold = "<< obj.threshold << "; "
                 << "num = " << obj.num;
             return out;
         }
+
+        // Reducer function to accumulate across multiple count objects
+        static CkReductionMsg* reduce_count(int nMsg, CkReductionMsg **msgs)
+        {
+            CkAssert(nMsg > 0);
+            count<cmp>* result = (count<cmp>*) msgs[0]->getData();
+            for (int i = 1; i < nMsg; ++i)
+            {
+                count<cmp> *contrib = (count<cmp>*) msgs[i]->getData();
+                // Compare the thresholds themselves and use the appropriate one
+                if ( result->c(contrib->threshold, result->threshold) )
+                    result->threshold = contrib->threshold;
+                result->num += contrib->num;
+            }
+            return CkReductionMsg::buildNew(sizeof(count<cmp>), result);
+        }
 };
 
 
@@ -33,6 +52,7 @@ class avg {
         int sum, num;
     public:
         avg(): sum(0), num(0) {}
+
         // Operate on an input element
         inline void operator() (int* first, int *last)
         {
@@ -40,8 +60,10 @@ class avg {
             for (int *ptr = first; ptr != last; ptr++)
                 sum += *ptr;
         }
+
         // Serialize the internals
         void pup(PUP::er &p) { p | sum; p | num; }
+
         // Spit results to ostream
         friend std::ostream& operator<< (std::ostream& out, const avg& obj) {
             out << "num = " << obj.num << "; "
@@ -49,8 +71,20 @@ class avg {
                 << "avg = " << ( obj.num ? (double)obj.sum/obj.num : obj.sum );
             return out;
         }
-};
-
 
+        // Reducer function to accumulate across multiple avg objects
+        static CkReductionMsg* reduce_avg(int nMsg, CkReductionMsg **msgs)
+        {
+            CkAssert(nMsg > 0);
+            avg* result = (avg*) msgs[0]->getData();
+            for (int i = 1; i < nMsg; ++i)
+            {
+                avg *contrib = (avg*) msgs[i]->getData();
+                result->sum += contrib->sum;
+                result->num += contrib->num;
+            }
+            return CkReductionMsg::buildNew(sizeof(avg), result);
+        }
+};
 
 #endif // UTILS_H