Enforce operations a bit more through handles
authorPhil Miller <mille121@illinois.edu>
Wed, 29 Apr 2009 17:25:26 +0000 (12:25 -0500)
committerPhil Miller <mille121@illinois.edu>
Thu, 10 Dec 2009 22:22:56 +0000 (16:22 -0600)
src/libs/ck-libs/multiphaseSharedArrays/msa-distArray.h

index 2b7d955ba25bf2705afeece88edb13f9995939e6..c75108bb3ccc846d5ee2942d1f72b3226f72b78b 100644 (file)
@@ -7,6 +7,28 @@
 
 struct MSA_InvalidHandle { };
 
+template <typename ENTRY>
+class Writable
+{
+    ENTRY &e;
+    
+public:
+    Writable(ENTRY &e_) : e(e_) {}
+    inline const ENTRY& operator= (const ENTRY& rhs) { e = rhs; }
+};
+
+template <typename ENTRY, class ENTRY_OPS_CLASS>
+class Accumulable
+{
+    ENTRY &e;
+    
+public:
+    Accumulable(ENTRY &e_) : e(e_) {}
+    void operator+=(const ENTRY &rhs_)
+        { ENTRY_OPS_CLASS::accumulate(e, rhs_); }
+};
+
+
 /**
    The MSA1D class is a handle to a distributed shared array of items
    of data type ENTRY. There are nEntries total numer of ENTRY's, with
@@ -87,6 +109,7 @@ public:
             return Handle::msa.get(idx); 
         }
         inline const ENTRY& operator[](unsigned int idx) { return get(idx); }
+        inline const ENTRY& operator()(unsigned int idx) { return get(idx); }
         inline const ENTRY& get2(unsigned int idx)
         {
             checkValid();
@@ -102,11 +125,13 @@ public:
             : Handle(msa_) { }
 
     public:
-        inline ENTRY& set(unsigned int idx)
+        inline Writable<ENTRY> set(unsigned int idx)
         {
             Handle::checkValid();
-            return Handle::msa.set(idx);
+            return Writable<ENTRY>(Handle::msa.set(idx));
         }
+        inline Writable<ENTRY> operator()(unsigned int idx)
+            { return set(idx); }
     };
 
     class Accum : public Handle
@@ -117,10 +142,10 @@ public:
             : Handle(msa_) { }
         using Handle::checkInvalidate;
     public:
-        inline ENTRY& accumulate(unsigned int idx)
+        inline Accumulable<ENTRY, ENTRY_OPS_CLASS> accumulate(unsigned int idx)
         { 
             Handle::checkValid();
-            return Handle::msa.accumulate(idx);
+            return Accumulable<ENTRY, ENTRY_OPS_CLASS>(Handle::msa.accumulate(idx));
         }
         inline void accumulate(unsigned int idx, const ENTRY& ent)
         {
@@ -136,6 +161,9 @@ public:
                     Handle::msa.accumulate(idx, *e);
                 }
         }
+
+        inline Accumulable<ENTRY, ENTRY_OPS_CLASS> operator() (unsigned int idx)
+            { return accumulate(idx); }
     };
 
 protected:
@@ -412,7 +440,7 @@ protected:
     ///   Merges together accumulates from different threads.
     inline void accumulate(unsigned int idx, const ENTRY& ent)
     {
-        accumulate(idx)+=ent;
+        ENTRY_OPS_CLASS::accumulate(accumulate(idx),ent);
     }
 
     /// Synchronize reads and writes across the entire array.
@@ -489,6 +517,12 @@ public:
             Handle::checkValid();
             return Handle::msa.get2(row, col);
         }
+
+        inline const ENTRY& operator() (unsigned int row, unsigned int col)
+            {
+                return get(row,col);
+            }
+
     };
 
     class Write : public Handle
@@ -499,10 +533,10 @@ public:
             :  Handle(msa_) { }
 
     public: 
-        inline ENTRY& set(unsigned int row, unsigned int col)
+        inline Writable<ENTRY> set(unsigned int row, unsigned int col)
         {
             Handle::checkValid();
-            return Handle::msa.set(row, col);
+            return Writable<ENTRY>(Handle::msa.set(row, col));
         }
     };