MSA Example: Make 'matmul' compile with new API
author Phil Miller <mille121@illinois.edu>
Thu, 10 Sep 2009 23:08:08 +0000 (18:08 -0500)
committer Phil Miller <mille121@illinois.edu>
Thu, 10 Dec 2009 22:23:01 +0000 (16:23 -0600)
The MSA API changed. Change the matrix multiplication
benchmark/example with all of its variations to match.

examples/multiphaseSharedArrays/matmul/t2d.C

index e277cf386f154233c882ec637402bc6f541fa429..82d8257c9353bbe53c65d8399ebdf72051aaf6cd 100644 (file)
@@ -48,7 +48,7 @@ protected:
 public:
     t2d(CkArgMsg* m)
     {
-        // Usage: a.out [number_of_worker_threads [max_bytes]]
+        // Usage: a.out number_of_worker_threads max_bytes ROWS1 ROWS2 COLS2 DECOMP-D TIMING-DETAIL?
         if(m->argc >1 ) NUM_WORKERS=atoi(m->argv[1]);
         if(m->argc >2 ) bytes=atoi(m->argv[2]);
         if(m->argc >3 ) ROWS1=atoi(m->argv[3]);
@@ -219,8 +219,11 @@ private:
 
 protected:
     MSA2DRowMjr arr1;       // row major
+       MSA2DRowMjr::Read *h1;
     MSA2DColMjr arr2;       // column major
+       MSA2DColMjr::Read *h2;
     MSA2DRowMjrC prod;       // product matrix
+       MSA2DRowMjrC::Handle *hp;
 
     unsigned int rows1, rows2, cols1, cols2, numWorkers;
 
@@ -231,7 +234,7 @@ protected:
         prod.enroll(numWorkers); // barrier
     }
 
-    void FillArrays()
+    void FillArrays(MSA2DRowMjr::Write &w1, MSA2DColMjr::Write &w2)
     {
         // fill in our portion of the array
         unsigned int rowStart, rowEnd, colStart, colEnd;
@@ -241,15 +244,15 @@ protected:
         // fill them in with 1
         for(unsigned int r = rowStart; r <= rowEnd; r++)
             for(unsigned int c = 0; c < cols1; c++)
-                arr1.set(r, c) = 1.0;
+                w1.set(r, c) = 1.0;
 
         for(unsigned int c = colStart; c <= colEnd; c++)
             for(unsigned int r = 0; r < rows2; r++)
-                arr2.set(r, c) = 1.0;
+                w2.set(r, c) = 1.0;
 
     }
 
-    void FillArrays2D()
+    void FillArrays2D(MSA2DRowMjr::Write &w1, MSA2DColMjr::Write &w2)
     {
         unsigned int rowStart, rowEnd, colStart, colEnd;
         unsigned int r, c;
@@ -262,7 +265,7 @@ protected:
         // fill them in with 1
         for(r = rowStart; r <= rowEnd; r++)
             for(c = colStart; c <= colEnd; c++)
-                arr1.set(r, c) = 1.0;
+                w1.set(r, c) = 1.0;
 
         // fill in our portion of the B matrix
         GetMyIndices(rows2, toX(), numWorkers2D(), rowStart, rowEnd);
@@ -271,16 +274,11 @@ protected:
         // fill them in with 1
         for(r = rowStart; r <= rowEnd; r++)
             for(c = colStart; c <= colEnd; c++)
-                arr2.set(r, c) = 1.0;
+                w2.set(r, c) = 1.0;
     }
 
-    void SyncArrays()
-    {
-        arr1.sync();
-        arr2.sync();
-    }
-
-    void TestResults(bool prod_test=true)
+    void TestResults(MSA2DRowMjr::Read &r1, MSA2DColMjr::Read &r2, MSA2DRowMjrC::Read &rp,
+                                        bool prod_test=true)
     {
         int errors = 0;
         bool ok=true;
@@ -289,8 +287,8 @@ protected:
         ok=true;
         for(unsigned int r = 0; ok && r < rows1; r++) {
             for(unsigned int c = 0; ok && c < cols1; c++) {
-                if(notequal(arr1.get(r, c), 1.0)) {
-                    ckout << "[" << CkMyPe() << "," << thisIndex << "] arr1 -- Illegal element at (" << r << "," << c << ") " << arr1.get(r,c) << endl;
+                if(notequal(r1.get(r, c), 1.0)) {
+                    ckout << "[" << CkMyPe() << "," << thisIndex << "] arr1 -- Illegal element at (" << r << "," << c << ") " << r1.get(r,c) << endl;
                     ok=false;
                     errors++;
                 }
@@ -300,8 +298,8 @@ protected:
         ok=true;
         for(unsigned int c = 0; ok && c < cols2; c++) {
             for(unsigned int r = 0; ok && r < rows2; r++) {
-                if(notequal(arr2.get(r, c), 1.0)) {
-                    ckout << "[" << CkMyPe() << "," << thisIndex << "] arr2 -- Illegal element at (" << r << "," << c << ") " << arr2.get(r,c) << endl;
+                if(notequal(r2.get(r, c), 1.0)) {
+                    ckout << "[" << CkMyPe() << "," << thisIndex << "] arr2 -- Illegal element at (" << r << "," << c << ") " << r2.get(r,c) << endl;
                     ok=false;
                     errors++;
                 }
@@ -316,8 +314,8 @@ protected:
             ok = true;
             for(unsigned int c = 0; ok && c < cols2; c++) {
                 for(unsigned int r = 0; ok && r < rows1; r++) {
-                    if(notequal(prod.get(r,c), 1.0 * cols1)) {
-                        ckout << "[" << CkMyPe() << "] result  -- Illegal element at (" << r << "," << c << ") " << prod.get(r,c) << endl;
+                    if(notequal(rp.get(r,c), 1.0 * cols1)) {
+                        ckout << "[" << CkMyPe() << "] result  -- Illegal element at (" << r << "," << c << ") " << rp.get(r,c) << endl;
                         ok=false;
                         errors++;
                     }
@@ -336,18 +334,23 @@ protected:
 
     // ============================= 1D ===================================
 
-    void FindProductNoPrefetch() {
+    void FindProductNoPrefetch(MSA2DRowMjr::Read &r1,
+                                                          MSA2DColMjr::Read &r2,
+                                                          MSA2DRowMjrC::Write &wp)
+       {
 #ifdef OLD
-        FindProductNoPrefetchNMK();
+        FindProductNoPrefetchNMK(r1, r2, wp);
 #else
-        FindProductNoPrefetchMKN_RM();
+        FindProductNoPrefetchMKN_RM(r1, r2, wp);
 #endif
     }
 
     // new, but bad perf
     // improved perf by taking the prod.accu out of the innermost loop, up 2
     // further improved perf by taking the arr1.get out of the innermost loop, up 1.
-    void FindProductNoPrefetchMKN_RM()
+    void FindProductNoPrefetchMKN_RM(MSA2DRowMjr::Read &r1,
+                                                                        MSA2DColMjr::Read &r2,
+                                                                        MSA2DRowMjrC::Write &wp)
     {
         CkAssert(arr2.getArrayLayout() == MSA_ROW_MAJOR);
 //         CkPrintf("reached\n");
@@ -359,25 +362,25 @@ protected:
             for(unsigned int c = 0; c < cols2; c++)
                 result[c] = 0;
             for(unsigned int k = 0; k < cols1; k++) { // K
-                double a = arr1.get(r,k);
+                double a = r1.get(r,k);
                 for(unsigned int c = 0; c < cols2; c++) { // N
-                    result[c] += a * arr2.get(k,c);
+                    result[c] += a * r2.get(k,c);
 //                     prod.set(r,c) = result; // @@ to see if accu is the delay
 //                     prod.accumulate(prod.getIndex(r,c), result);
                 }
 //              assert(!notequal(result, 1.0*cols1));
             }
             for(unsigned int c = 0; c < cols2; c++) {
-                prod.set(r,c) = result[c];
+                wp.set(r,c) = result[c];
             }
         }
         delete [] result;
-
-        prod.sync();
     }
 
     // old
-    void FindProductNoPrefetchNMK()
+    void FindProductNoPrefetchNMK(MSA2DRowMjr::Read &r1,
+                                                                 MSA2DColMjr::Read &r2,
+                                                                 MSA2DRowMjrC::Write &wp)
     {
         unsigned int rowStart, rowEnd;
         GetMyIndices(rows1, thisIndex, numWorkers, rowStart, rowEnd);
@@ -387,27 +390,25 @@ protected:
 
                 double result = 0.0;
                 for(unsigned int k = 0; k < cols1; k++) { // K
-                    double e1 = arr1.get(r,k);
-                    double e2 = arr2.get(k,c);
+                    double e1 = r1.get(r,k);
+                    double e2 = r2.get(k,c);
                     result += e1 * e2;
                 }
 //              assert(!notequal(result, 1.0*cols1));
 
-                prod.set(r,c) = result;
+                wp.set(r,c) = result;
             }
         }
-
-        prod.sync();
     }
 
     // Assumes that the nepp equals the size of a row, i.e. NEPP == COLS1 == ROWS2
-    void FindProductNoPrefetchStripMined()
+    void FindProductNoPrefetchStripMined(MSA2DRowMjrC::Write &wp)
     {
-        FindProductNoPrefetchStripMinedMKN_ROWMJR();
+        FindProductNoPrefetchStripMinedMKN_ROWMJR(wp);
     }
 
     // Assumes that the nepp equals the size of a row, i.e. NEPP == COLS1 == ROWS2
-    void FindProductNoPrefetchStripMinedNMK()
+    void FindProductNoPrefetchStripMinedNMK(MSA2DRowMjrC::Write &wp)
     {
         CkAssert(NEPP == cols1);
         unsigned int rowStart, rowEnd;
@@ -428,19 +429,17 @@ protected:
                 }
 //              assert(!notequal(result, 1.0*cols1));
 
-                prod.set(r,c) = result;
+                wp.set(r,c) = result;
             }
         }
-        double time2 = CmiWallTimer();
 
-        prod.sync();
-        double time3 = CmiWallTimer();
-        CkPrintf("timings %f %f\n", time2-time1, time3-time2);
+        double time2 = CmiWallTimer();
+        CkPrintf("timings %f \n", time2-time1);
     }
 
     // Assumes that the nepp equals the size of a row, i.e. NEPP == COLS1 == ROWS2
     // Assumes CkAssert(NEPP_C == cols2);
-    void FindProductNoPrefetchStripMinedMKN_ROWMJR()
+    void FindProductNoPrefetchStripMinedMKN_ROWMJR(MSA2DRowMjrC::Write &wp)
     {
         CkAssert(NEPP == cols1);
         CkAssert(NEPP_C == cols2);
@@ -452,7 +451,7 @@ protected:
         for(unsigned int r = rowStart; r <= rowEnd; r++) {  // M
             double* a = &(arr1.getPageBottom(arr1.getIndex(r,0),Read_Fault)); // ptr to row of A
             for(unsigned int c = 0; c < cols2; c++) { // N
-                prod.set(r,c);  // just mark it as updated, need a better way
+                wp.set(r,c);  // just mark it as updated, need a better way
             }
             double* cm = &(prod.getPageBottom(prod.getIndex(r,0),Write_Fault)); // ptr to row of C
             for(unsigned int k = 0; k < cols1; k++) { // K
@@ -464,11 +463,11 @@ protected:
             }
                        if (r%4==0) CthYield();
         }
-
-        prod.sync();
     }
 
-    void FindProductWithPrefetch()
+    void FindProductWithPrefetch(MSA2DRowMjr::Read &r1,
+                                                                MSA2DColMjr::Read &r2,
+                                                                MSA2DRowMjrC::Write &wp)
     {
         // fill in our portion of the array
         unsigned int rowStart, rowEnd;
@@ -484,14 +483,14 @@ protected:
         if(arr1.WaitAll())
         {
             if(verbose) ckout << thisIndex << ": Out of buffer in prefetch 1" << endl;
-            FindProductNoPrefetch();
+            FindProductNoPrefetch(r1, r2, wp);
             return;
         }
 
         if(arr2.WaitAll())
         {
             if(verbose) ckout << thisIndex << ": Out of buffer in prefetch 2" << endl;
-            FindProductNoPrefetch();
+            FindProductNoPrefetch(r1, r2, wp);
             return;
         }
 
@@ -504,24 +503,23 @@ protected:
                 double result = 0.0;
                 for(unsigned int k = 0; k < cols1; k++)
                 {
-                    double e1 = arr1.get2(r,k);
-                    double e2 = arr2.get2(k,c);
+                    double e1 = r1.get2(r,k);
+                    double e2 = r2.get2(k,c);
                     result += e1 * e2;
                 }
 
                 //ckout << "[" << r << "," << c << "] = " << result << endl;
 
-                prod.set(r,c) = result;
+                wp(r,c) = result;
             }
             //ckout << thisIndex << "." << endl;
         }
 
         //arr1.Unlock(); arr2.Unlock();
-        prod.sync();
     }
 
     // ============================= 2D ===================================
-    void FindProductNoPrefetch2DStripMined()
+    void FindProductNoPrefetch2DStripMined(MSA2DRowMjrC::Write &wp)
     {
         CkAssert(NEPP == cols1);
         unsigned int rowStart, rowEnd, colStart, colEnd;
@@ -543,14 +541,14 @@ protected:
                 }
 //              assert(!notequal(result, 1.0*cols1));
 
-                prod.set(r,c) = result;
+                wp.set(r,c) = result;
             }
         }
-
-        prod.sync();
     }
 
-    void FindProductNoPrefetch2D()
+    void FindProductNoPrefetch2D(MSA2DRowMjr::Read &r1,
+                                                                MSA2DColMjr::Read &r2,
+                                                                MSA2DRowMjrC::Write &wp)
     {
         unsigned int rowStart, rowEnd, colStart, colEnd;
         // fill in our portion of the C matrix
@@ -562,21 +560,21 @@ protected:
 
                 double result = 0.0;
                 for(unsigned int k = 0; k < cols1; k++) {
-                    double e1 = arr1.get(r,k);
-                    double e2 = arr2.get(k,c);
+                    double e1 = r1.get(r,k);
+                    double e2 = r2.get(k,c);
                     result += e1 * e2;
                 }
 //              assert(!notequal(result, 1.0*cols1));
 
-                prod.set(r,c) = result;
+                wp.set(r,c) = result;
             }
         }
-
-        prod.sync();
     }
 
     // ============================= 3D ===================================
-    void FindProductNoPrefetch3D()
+    void FindProductNoPrefetch3D(MSA2DRowMjr::Read &r1,
+                                                                MSA2DColMjr::Read &r2,
+                                                                MSA2DRowMjrC::Accum &ap)
     {
         unsigned int rowStart, rowEnd, colStart, colEnd, kStart, kEnd;
         // fill in our portion of the C matrix
@@ -589,20 +587,18 @@ protected:
 
                 double result = 0.0;
                 for(unsigned int k = kStart; k <= kEnd; k++) {
-                    double e1 = arr1.get(r,k);
-                    double e2 = arr2.get(k,c);
+                    double e1 = r1.get(r,k);
+                    double e2 = r2.get(k,c);
                     result += e1 * e2;
                 }
 //              assert(!notequal(result, 1.0*cols1));
 
-                prod.accumulate(prod.getIndex(r,c), result);
+                ap(r,c) += result;
             }
         }
-
-        prod.sync();
     }
 
-    void FindProductNoPrefetch3DStripMined()
+    void FindProductNoPrefetch3DStripMined(MSA2DRowMjrC::Accum &ap)
     {
         CkAssert(NEPP == cols1);
         unsigned int rowStart, rowEnd, colStart, colEnd, kStart, kEnd;
@@ -624,11 +620,9 @@ protected:
                 }
 //              assert(!notequal(result, 1.0*cols1));
 
-                prod.accumulate(prod.getIndex(r,c), result);
+                ap(r,c) += result;
             }
         }
-
-        prod.sync();
     }
 
     // ================================================================
@@ -660,24 +654,26 @@ public:
         times.push_back(CkWallTimer()); // 3
         description.push_back("   enroll");
 
+               MSA2DRowMjr::Write &w1 = arr1.getInitialWrite();
+               MSA2DColMjr::Write &w2 = arr2.getInitialWrite();
+
         if(verbose) ckout << thisIndex << ": filling" << endl;
         switch(DECOMPOSITION){
         case 1:
         case 3:
         case 4:
         case 6:
-            FillArrays();
+            FillArrays(w1, w2);
             break;
         case 2:
         case 5:
-            FillArrays2D();
+            FillArrays2D(w1, w2);
             break;
         }
         times.push_back(CkWallTimer()); // 4
         description.push_back("  fill");
 
         if(verbose) ckout << thisIndex << ": syncing" << endl;
-        SyncArrays();
         times.push_back(CkWallTimer()); // 5
         description.push_back("    sync");
 
@@ -685,32 +681,43 @@ public:
 
         if(verbose) ckout << thisIndex << ": product" << endl;
 
+               MSA2DRowMjr::Read &r1 = arr1.syncToRead(w1);
+               MSA2DColMjr::Read &r2 = arr2.syncToRead(w2);
+
+               hp = &(prod.getInitialWrite());
+               MSA2DRowMjrC::Write &wp = * (MSA2DRowMjrC::Write *) hp;
+               MSA2DRowMjrC::Accum &ap = * (MSA2DRowMjrC::Accum *) hp;
+
         switch(DECOMPOSITION) {
         case 1:
             if (runPrefetchVersion)
-                FindProductWithPrefetch();
+                FindProductWithPrefetch(r1, r2, wp);
             else
-                FindProductNoPrefetch();
+                FindProductNoPrefetch(r1, r2, wp);
             break;
         case 2:
-            FindProductNoPrefetch2D();
+            FindProductNoPrefetch2D(r1, r2, wp);
             break;
         case 3:
-            FindProductNoPrefetch3D();
+            FindProductNoPrefetch3D(r1, r2, ap);
             break;
         case 4:
-            FindProductNoPrefetchStripMined();
+            FindProductNoPrefetchStripMined(wp);
             break;
         case 5:
-            FindProductNoPrefetch2DStripMined();
+            FindProductNoPrefetch2DStripMined(wp);
             break;
         case 6:
-            FindProductNoPrefetch3DStripMined();
+            FindProductNoPrefetch3DStripMined(ap);
             break;
         }
         times.push_back(CkWallTimer()); // 6
         description.push_back("    work");
 
+               h1 = &r1;
+               h2 = &r2;
+               hp = &(prod.syncToRead(*hp));
+
         Contribute();
     }
 
@@ -721,7 +728,7 @@ public:
         description.push_back("    redn");
 
         if(verbose) ckout << thisIndex << ": testing" << endl;
-        if (do_test) TestResults();
+        if (do_test) TestResults(*h1, *h2, * (MSA2DRowMjrC::Read *) hp);
         times.push_back(CkWallTimer()); // 5
         description.push_back("    test");
         Contribute();