Example charm++/matmul: simplify logic for when to pass blocks
[charm.git] / examples / charm++ / matmul / matmul.ci
1 mainmodule matmul {
2   readonly CProxy_Main mainProxy;
3   mainchare Main {
4     entry Main(CkArgMsg *m);
5     entry [reductiontarget] void done();
6   };
7
8   array [2D] Block {
9     entry Block(unsigned int blockSize, unsigned int numBlocks);
10     entry void pdgemmSendInput(CProxy_Block output, bool aOrB) {
11       atomic {
12         if (aOrB)
13           output[thisIndex].inputA(0, data, blockSize);
14         else
15           output[thisIndex].inputB(0, data, blockSize);
16       }
17     };
18
19     entry void pdgemmRun(double alpha, double beta, CkCallback done) {
20       forall [block] (0:numBlocks-1,1) {
21         when
22           inputA[block](int blockIdA, double blockA[blockSizeA*blockSizeA],
23                         unsigned int blockSizeA),
24           inputB[block](int blockIdB, double blockB[blockSizeB*blockSizeB],
25                         unsigned int blockSizeB) atomic {
26           CkAssert(blockSizeA == blockSizeB);
27           cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
28                       blockSize, blockSize, blockSize,
29                       alpha,
30                       blockA, blockSize, blockB, blockSize,
31                       beta, data, blockSize);
32           if (blockIdA != numBlocks) {
33             int destX = (thisIndex.x + 1) % numBlocks;
34             int destY = thisIndex.y;
35             thisProxy(destX, destY).inputA(blockIdA+1, blockA, blockSizeA);
36           }
37           if (blockIdB != numBlocks) {
38             int destX = thisIndex.x;
39             int destY = (thisIndex.y + 1) % numBlocks;
40             thisProxy(destX, destY).inputB(blockIdB+1, blockB, blockSizeB);
41           }
42         }
43       }
44       atomic {
45         contribute(done);
46       }
47     };
48     entry void inputA(int blockIdA, double blockA[blockSizeA*blockSizeA], unsigned int blockSizeA);
49     entry void inputB(int blockIdB, double blockB[blockSizeB*blockSizeB], unsigned int blockSizeB);
50   };
51 };