Example charm++/matmul: generalize implementation body to varying sizes
author Phil Miller <mille121@illinois.edu>
Tue, 8 May 2012 19:40:15 +0000 (14:40 -0500)
committer Phil Miller <mille121@illinois.edu>
Tue, 8 May 2012 19:43:58 +0000 (14:43 -0500)
examples/charm++/matmul/matmul.ci

index 7b8696afd9c584491683db7d145a13e8e7b0ca87..10e378992ca60b35936092234a948e3549834843 100644 (file)
@@ -10,34 +10,35 @@ mainmodule matmul {
     entry void pdgemmSendInput(CProxy_Block output, bool aOrB) {
       atomic {
         if (aOrB)
-          output[thisIndex].inputA(0, data, blockSize);
+          output[thisIndex].inputA(0, data, blockSize, blockSize);
         else
-          output[thisIndex].inputB(0, data, blockSize);
+          output[thisIndex].inputB(0, data, blockSize, blockSize);
       }
     };
 
     entry void pdgemmRun(double alpha, double beta, CkCallback done) {
       forall [block] (0:numBlocks-1,1) {
         when
-          inputA[block](int blockIdA, double blockA[blockSizeA*blockSizeA],
-                        unsigned int blockSizeA),
-          inputB[block](int blockIdB, double blockB[blockSizeB*blockSizeB],
-                        unsigned int blockSizeB) atomic {
-          CkAssert(blockSizeA == blockSizeB);
+          inputA[block](int blockIdA, double blockA[M*KA], unsigned int M, unsigned int KA),
+          inputB[block](int blockIdB, double blockB[KB*N], unsigned int KB, unsigned int N)
+          atomic {
+          CkAssert(KA == KB);
+
           cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
-                      blockSize, blockSize, blockSize,
+                      M, N, KA,
                       alpha,
-                      blockA, blockSize, blockB, blockSize,
-                      beta, data, blockSize);
+                      blockA, KA, blockB, N,
+                      beta, data, N);
+
           if (blockIdA != numBlocks) {
             int destX = (thisIndex.x + 1) % numBlocks;
             int destY = thisIndex.y;
-            thisProxy(destX, destY).inputA(blockIdA+1, blockA, blockSizeA);
+            thisProxy(destX, destY).inputA(blockIdA+1, blockA, M, KA);
           }
           if (blockIdB != numBlocks) {
             int destX = thisIndex.x;
             int destY = (thisIndex.y + 1) % numBlocks;
-            thisProxy(destX, destY).inputB(blockIdB+1, blockB, blockSizeB);
+            thisProxy(destX, destY).inputB(blockIdB+1, blockB, KB, N);
           }
         }
       }
@@ -45,7 +46,7 @@ mainmodule matmul {
         contribute(done);
       }
     };
-    entry void inputA(int blockIdA, double blockA[blockSizeA*blockSizeA], unsigned int blockSizeA);
-    entry void inputB(int blockIdB, double blockB[blockSizeB*blockSizeB], unsigned int blockSizeB);
+    entry void inputA(int blockIdA, double blockA[M*KA], unsigned int M, unsigned int KA);
+    entry void inputB(int blockIdB, double blockB[KB*N], unsigned int KB, unsigned int N);
   };
 };