Example: Simple Charm++ matrix-matrix multiply
[charm.git] / examples / charm++ / matmul / matmul.ci
1 mainmodule matmul {
2   readonly CProxy_Main mainProxy;
3   mainchare Main {
4     entry Main(CkArgMsg *m);
5     entry [reductiontarget] void done();
6   };
7
8   array [2D] Block {
9     entry Block(unsigned int blockSize, unsigned int numBlocks);
10     entry void pdgemmSendInput(CProxy_Block output, bool aOrB) {
11       atomic {
12         if (aOrB)
13           output[thisIndex].inputA(thisIndex.x, data, blockSize, true);
14         else
15           output[thisIndex].inputB(thisIndex.y, data, blockSize, true);
16       }
17     };
18
19     entry void pdgemmRun(double alpha, double beta, CkCallback done) {
20       forall [block] (0:numBlocks-1,1) {
21         when
22           inputA[block](int blockIdA, double blockA[blockSizeA*blockSizeA],
23                         unsigned int blockSizeA, bool fromSourceA),
24           inputB[block](int blockIdB, double blockB[blockSizeB*blockSizeB],
25                         unsigned int blockSizeB, bool fromSourceB) atomic {
26           CkAssert(blockSizeA == blockSizeB);
27           cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
28                       blockSize, blockSize, blockSize,
29                       alpha,
30                       blockA, blockSize, blockB, blockSize,
31                       beta, data, blockSize);
32           if (fromSourceA || ((blockIdA + numBlocks - 1) % numBlocks != thisIndex.x)) {
33             int destX = (thisIndex.x + 1) % numBlocks;
34             int destY = thisIndex.y;
35             thisProxy(destX, destY).inputA(blockIdA, blockA, blockSizeA);
36           }
37           if (fromSourceB || ((blockIdB + numBlocks - 1) % numBlocks != thisIndex.y)) {
38             int destX = thisIndex.x;
39             int destY = (thisIndex.y + 1) % numBlocks;
40             thisProxy(destX, destY).inputB(blockIdB, blockB, blockSizeB);
41           }
42         }
43       }
44       atomic {
45         contribute(done);
46       }
47     };
48     entry void inputA(int blockIdA, double blockA[blockSizeA*blockSizeA], unsigned int blockSizeA, bool fromSource = false);
49     entry void inputB(int blockIdB, double blockB[blockSizeB*blockSizeB], unsigned int blockSizeB, bool fromSource = false);
50   };
51 };